68 changes: 49 additions & 19 deletions test/prototype/test_parq.py
@@ -21,7 +21,6 @@
from torchao.prototype.parq.quant import (
Int4UnifTorchaoQuantizer,
LSBQuantizer,
Quantizer,
StretchedIntxWeightConfig,
StretchedUnifTorchaoQuantizer,
TernaryUnifQuantizer,
@@ -162,29 +161,43 @@ def build_param_groups(
model,
b: int = 2,
group_size: Optional[int] = None,
quantizer: Optional[Quantizer] = None,
embed_b: int = 4,
):
params_quant, params_embed, params_no_quant = split_param_groups(model)
quant_kwargs = {}
if group_size:
quant_kwargs["quant_block_size"] = group_size
if quantizer is not None:
quant_kwargs["quantizer"] = quantizer
param_groups = [
{"params": params_quant, "quant_bits": b, **quant_kwargs},
{"params": params_no_quant},
]
if params_embed:
param_groups.append(
{
"params": params_embed,
"quant_bits": 4,
"quantizer": UnifTorchaoQuantizer(),
}
)
param_groups.append({"params": params_embed, "quant_bits": embed_b})
return param_groups


def get_optim_kwargs(
model, base_optimizer, embedding=True, quant_cls=UnifTorchaoQuantizer
):
optim_kwargs = {}
if embedding:
embed_data_ptrs = set(
(
m.weight.data_ptr()
for m in model.modules()
if isinstance(m, nn.Embedding)
)
)
group_idx = -1
for i, group in enumerate(base_optimizer.param_groups):
if all(p.data_ptr() in embed_data_ptrs for p in group["params"]):
group_idx = i
break
assert group_idx > -1
optim_kwargs["group_quantizer_map"] = {group_idx: quant_cls()}
return optim_kwargs
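
For reference, a minimal usage sketch of the new helper above. The toy model, param-group layout, and bit widths are assumptions for illustration only (not part of this diff); `get_optim_kwargs` and `UnifTorchaoQuantizer` are the definitions and imports already available in this test module.

```python
import torch
import torch.nn as nn

# Hypothetical toy model: one embedding layer plus one linear layer.
model = nn.ModuleDict({"embed": nn.Embedding(32, 16), "linear": nn.Linear(16, 16)})
param_groups = [
    {"params": list(model["linear"].parameters()), "quant_bits": 2},
    {"params": [model["embed"].weight], "quant_bits": 4},
]
base_optimizer = torch.optim.AdamW(param_groups)

# Group 1 contains only the embedding weight, so get_optim_kwargs resolves it to
# {"group_quantizer_map": {1: UnifTorchaoQuantizer()}} for QuantOptimizer.
optim_kwargs = get_optim_kwargs(model, base_optimizer)
```

This replaces the per-group "quantizer" entry that build_param_groups used to inject into the embedding group.
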


def compare_quantized_models(
model: nn.Module,
m_ref: nn.Module,
@@ -222,7 +235,7 @@ def compare_parq_convert(
orig_model = copy.deepcopy(model) # save copy of PARQ quantized model

# equivalent to torchao's convert step
optimizer.torchao_convert(model, weight_only=weight_only)
optimizer.torchao_convert(model, weight_only=weight_only, embed_weight_only=True)

inputs = model.example_inputs(device=_DEVICE)
torch.testing.assert_close(model(inputs), orig_model(inputs))
@@ -290,15 +303,16 @@ def test_parq_train_loop(
quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
else:
quantizer = LSBQuantizer()
param_groups = build_param_groups(
model, b, quantizer=quantizer if per_group_quantizer else None
)
param_groups = build_param_groups(model, b, embed_b=b)
base_optimizer = torch.optim.AdamW(param_groups)

prox_map = (
ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2)
)
optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
optim_kwargs = get_optim_kwargs(
model, base_optimizer, quant_cls=type(quantizer), embedding=False
)
optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map, **optim_kwargs)
for _ in range(3):
x = model.example_inputs(device=_DEVICE)
out = model(x)
@@ -367,11 +381,13 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):

b = 4
base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
optimizer = QuantOptimizer(
base_optimizer,
Int4UnifTorchaoQuantizer(),
ProxHardQuant(),
quant_per_channel=True,
**optim_kwargs,
)
compare_parq_convert(model, m_ref, optimizer, weight_only=True)

@@ -387,11 +403,13 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
quantize_(m_ref, config)

base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
optimizer = QuantOptimizer(
base_optimizer,
UnifTorchaoQuantizer(),
ProxHardQuant(),
quant_per_channel=True,
**optim_kwargs,
)
compare_parq_convert(model, m_ref, optimizer, weight_only=True)
check_torchao_tensor_subclass(self, model, weight_only=True)
@@ -462,11 +480,13 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
quantize_(m_ref, config, filter_fn=_is_linear)

base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
optimizer = QuantOptimizer(
base_optimizer,
quantizer,
ProxHardQuant(),
quant_per_channel=True,
**optim_kwargs,
)
compare_parq_convert(model, m_ref, optimizer, weight_only=True)
check_torchao_tensor_subclass(self, model, weight_only=True)
@@ -482,14 +502,19 @@ def test_intx_weight_only_tied_embed_linear(

quantizer = StretchedUnifTorchaoQuantizer(b)
base_optimizer = torch.optim.SGD(build_param_groups(model, b))
optim_kwargs = get_optim_kwargs(model, base_optimizer)
optimizer = QuantOptimizer(
base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
base_optimizer,
quantizer,
ProxHardQuant(),
quant_per_channel=True,
**optim_kwargs,
)
optimizer.zero_grad()
optimizer.step()

apply_activation_quantization(model, optimizer, model_dtype)
optimizer.torchao_convert(model)
optimizer.torchao_convert(model, embed_weight_only=True)
check_torchao_tensor_subclass(self, model)
self.assertTrue(
torch.equal(model.embed_tokens.weight.qdata, model.linear2.weight.qdata)
@@ -531,8 +556,13 @@ def test_int8_dynamic_activation_intx_e2e(

# quantize weights with PARQ
base_optimizer = torch.optim.SGD(build_param_groups(model, b, group_size))
optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
optimizer = QuantOptimizer(
base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
base_optimizer,
quantizer,
ProxHardQuant(),
quant_per_channel=True,
**optim_kwargs,
)

optimizer.zero_grad()
32 changes: 26 additions & 6 deletions torchao/prototype/parq/optim/quantopt.py
@@ -56,13 +56,15 @@ def __init__(
quant_per_channel: bool = False,
quant_shrink: bool = False,
anneal_wd_frac: float = 0.0,
group_quantizer_map: Optional[dict[int, Quantizer]] = None,
) -> None:
if not 0 <= anneal_wd_frac <= 1:
raise ValueError(f"Invalid {anneal_wd_frac=} outside range [0.0, 1.0]")

# need to reconstruct these objects if loading checkpoint
self.base_optimizer = base_optimizer
self.quantizer = quantizer
self.group_quantizer_map = group_quantizer_map
self.prox_map = prox_map

# need to store these attributes in state_dict for checkpoint
@@ -153,20 +155,39 @@ def _filter_fn(module: nn.Module, *args, param_set) -> bool:
for param_set in self._param_sets():
yield partial(_filter_fn, param_set=param_set)

def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
def _get_quantizer(self, group_idx: int) -> Optional[Quantizer]:
if self.group_quantizer_map and group_idx in self.group_quantizer_map:
return self.group_quantizer_map[group_idx]
return self.quantizer

def torchao_convert(
self,
model: nn.Module,
weight_only: bool = False,
embed_weight_only: bool = False,
) -> None:
"""Converts model parameters to torchao quantized tensor subclasses."""
model.eval()
self.restore_latent_params()

# TODO(lvj): find more robust way to identify embedding layers
embed_data_ptrs = set()
linear_data_ptrs = set()
embed_modules = []
for module in model.modules():
if isinstance(module, nn.Embedding):
embed_modules.append(module)
embed_data_ptrs.add(module.weight.data_ptr())
elif _is_linear(module) and module.weight.data_ptr() not in embed_data_ptrs:
linear_data_ptrs.add(module.weight.data_ptr())

tied_embeddings = False
if not embed_weight_only and getattr(model, "_tied_weights_keys", None):
# Workaround for dynamic activations on tied embeddings
tied_embeddings = True
for module in embed_modules:
setattr(module, "bias", None)

filter_fns = []
configs = []
attach_hf_config = _is_hf_model(model)
@@ -175,7 +196,7 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
zip(self.regularized_param_groups(), self.get_filter_fns(model))
):
filter_fns.append(filter_fn)
quantizer = group.get("quantizer", self.quantizer)
quantizer = self._get_quantizer(i)
if not isinstance(quantizer, UnifTorchaoQuantizer) or not group["params"]:
configs.append(None)
continue
@@ -187,7 +208,7 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
any_embed = any(p.data_ptr() in embed_data_ptrs for p in group["params"])
config = _get_config_from_quantizer(
quantizer,
weight_only or any_embed,
weight_only or (any_embed and not tied_embeddings),
device,
group["quant_bits"],
group.get("quant_block_size"),
@@ -255,10 +276,9 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
else:
quant_update = False

for group in self.regularized_param_groups():
for i, group in enumerate(self.regularized_param_groups()):
# Override quantizer if specified in the group
quantizer = group.get("quantizer", self.quantizer)
assert isinstance(quantizer, Quantizer), f"Invalid {quantizer=}"
quantizer = self._get_quantizer(i)

# AProx in practice: ensure shrinkage coefficient >= 1
group["cumu_lr"] += group["lr"]