From 9c0a92b63a54e3761274545934a362c3541c64bc Mon Sep 17 00:00:00 2001
From: lisjin
Date: Tue, 30 Sep 2025 12:28:45 -0700
Subject: [PATCH 1/4] Replace param_group quantizer instance with QuantOptimizer attribute

---
 torchao/prototype/parq/optim/quantopt.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/torchao/prototype/parq/optim/quantopt.py b/torchao/prototype/parq/optim/quantopt.py
index bfa651dcc9..f78133eb03 100644
--- a/torchao/prototype/parq/optim/quantopt.py
+++ b/torchao/prototype/parq/optim/quantopt.py
@@ -56,6 +56,7 @@ def __init__(
         quant_per_channel: bool = False,
         quant_shrink: bool = False,
         anneal_wd_frac: float = 0.0,
+        group_quantizer_map: Optional[dict[int, Quantizer]] = None,
     ) -> None:
         if not 0 <= anneal_wd_frac <= 1:
             raise ValueError(f"Invalid {anneal_wd_frac=} outside range [0.0, 1.0]")
@@ -63,6 +64,7 @@
         # need to reconstruct these objects if loading checkpoint
         self.base_optimizer = base_optimizer
         self.quantizer = quantizer
+        self.group_quantizer_map = group_quantizer_map
         self.prox_map = prox_map
 
         # need to store these attributes in state_dict for checkpoint
@@ -153,6 +155,11 @@ def _filter_fn(module: nn.Module, *args, param_set) -> bool:
         for param_set in self._param_sets():
             yield partial(_filter_fn, param_set=param_set)
 
+    def _get_quantizer(self, group_idx: int) -> Optional[Quantizer]:
+        if self.group_quantizer_map and group_idx in self.group_quantizer_map:
+            return self.group_quantizer_map[group_idx]
+        return self.quantizer
+
     def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
         """Converts model parameters to torchao quantized tensor subclasses."""
         model.eval()
@@ -175,7 +182,7 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
             zip(self.regularized_param_groups(), self.get_filter_fns(model))
         ):
             filter_fns.append(filter_fn)
-            quantizer = group.get("quantizer", self.quantizer)
+            quantizer = self._get_quantizer(i)
             if not isinstance(quantizer, UnifTorchaoQuantizer) or not group["params"]:
                 configs.append(None)
                 continue
@@ -255,10 +262,9 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
         else:
             quant_update = False
 
-        for group in self.regularized_param_groups():
+        for i, group in enumerate(self.regularized_param_groups()):
             # Override quantizer if specified in the group
-            quantizer = group.get("quantizer", self.quantizer)
-            assert isinstance(quantizer, Quantizer), f"Invalid {quantizer=}"
+            quantizer = self._get_quantizer(i)
 
             # AProx in practice: ensure shrinkage coefficient >= 1
             group["cumu_lr"] += group["lr"]
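Usage sketch (not part of the diff above): after this patch, a per-group quantizer override is passed to QuantOptimizer through the new group_quantizer_map keyword instead of a "quantizer" entry inside the param group dict. The toy model, bit widths, and param-group layout below are illustrative assumptions; the import paths follow the package layout shown in the diffs and may need adjusting.

    import torch
    import torch.nn as nn

    from torchao.prototype.parq.optim import ProxHardQuant, QuantOptimizer
    from torchao.prototype.parq.quant import UnifQuantizer, UnifTorchaoQuantizer

    # Illustrative toy model: one embedding layer and one linear layer.
    model = nn.Sequential(nn.Embedding(128, 16), nn.Linear(16, 8))

    # Group 0 quantizes the linear weight to 2 bits, group 1 the embedding
    # weight to 4 bits; groups without "quant_bits" are left unquantized.
    param_groups = [
        {"params": model[1].parameters(), "quant_bits": 2},
        {"params": model[0].parameters(), "quant_bits": 4},
    ]
    base_optimizer = torch.optim.SGD(param_groups, lr=0.1)

    # Group index 1 overrides the default quantizer; every other regularized
    # group falls back to the UnifQuantizer passed positionally.
    optimizer = QuantOptimizer(
        base_optimizer,
        UnifQuantizer(),
        ProxHardQuant(),
        group_quantizer_map={1: UnifTorchaoQuantizer()},
    )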
From f07488e7cfe50d5fd4b84ae092d1d91ca275bf75 Mon Sep 17 00:00:00 2001
From: lisjin
Date: Wed, 1 Oct 2025 10:06:18 -0700
Subject: [PATCH 2/4] Fix broken test_parq.py

---
 test/prototype/test_parq.py | 48 ++++++++++++++++++++++++------------
 1 file changed, 31 insertions(+), 17 deletions(-)

diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py
index 10004a03f9..e9ccdb1e99 100644
--- a/test/prototype/test_parq.py
+++ b/test/prototype/test_parq.py
@@ -21,7 +21,6 @@
 from torchao.prototype.parq.quant import (
     Int4UnifTorchaoQuantizer,
     LSBQuantizer,
-    Quantizer,
     StretchedIntxWeightConfig,
     StretchedUnifTorchaoQuantizer,
     TernaryUnifQuantizer,
@@ -162,29 +161,29 @@ def build_param_groups(
     model,
     b: int = 2,
     group_size: Optional[int] = None,
-    quantizer: Optional[Quantizer] = None,
 ):
     params_quant, params_embed, params_no_quant = split_param_groups(model)
     quant_kwargs = {}
     if group_size:
         quant_kwargs["quant_block_size"] = group_size
-    if quantizer is not None:
-        quant_kwargs["quantizer"] = quantizer
     param_groups = [
         {"params": params_quant, "quant_bits": b, **quant_kwargs},
         {"params": params_no_quant},
     ]
     if params_embed:
-        param_groups.append(
-            {
-                "params": params_embed,
-                "quant_bits": 4,
-                "quantizer": UnifTorchaoQuantizer(),
-            }
-        )
+        param_groups.append({"params": params_embed, "quant_bits": 4})
     return param_groups
 
 
+def get_optim_kwargs(base_optimizer, embedding=True, quant_cls=UnifTorchaoQuantizer):
+    optim_kwargs = {}
+    if embedding:
+        group_idx = len(base_optimizer.param_groups) - 2
+        assert group_idx > -1
+        optim_kwargs["group_quantizer_map"] = {group_idx: quant_cls()}
+    return optim_kwargs
+
+
 def compare_quantized_models(
     model: nn.Module,
     m_ref: nn.Module,
@@ -290,15 +289,14 @@ def test_parq_train_loop(
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
-        param_groups = build_param_groups(
-            model, b, quantizer=quantizer if per_group_quantizer else None
-        )
+        param_groups = build_param_groups(model, b)
         base_optimizer = torch.optim.AdamW(param_groups)
 
         prox_map = (
             ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2)
         )
-        optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
+        optim_kwargs = get_optim_kwargs(base_optimizer)
+        optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map, **optim_kwargs)
         for _ in range(3):
             x = model.example_inputs(device=_DEVICE)
             out = model(x)
@@ -367,11 +365,13 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
 
         b = 4
         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
+        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             Int4UnifTorchaoQuantizer(),
             ProxHardQuant(),
             quant_per_channel=True,
+            **optim_kwargs,
         )
         compare_parq_convert(model, m_ref, optimizer, weight_only=True)
 
@@ -387,11 +387,13 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         quantize_(m_ref, config)
 
         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
+        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             UnifTorchaoQuantizer(),
             ProxHardQuant(),
             quant_per_channel=True,
+            **optim_kwargs,
         )
         compare_parq_convert(model, m_ref, optimizer, weight_only=True)
         check_torchao_tensor_subclass(self, model, weight_only=True)
@@ -462,11 +464,13 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         quantize_(m_ref, config, filter_fn=_is_linear)
 
         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
+        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             quantizer,
             ProxHardQuant(),
             quant_per_channel=True,
+            **optim_kwargs,
        )
         compare_parq_convert(model, m_ref, optimizer, weight_only=True)
         check_torchao_tensor_subclass(self, model, weight_only=True)
@@ -482,8 +486,13 @@ def test_intx_weight_only_tied_embed_linear(
         quantizer = StretchedUnifTorchaoQuantizer(b)
 
         base_optimizer = torch.optim.SGD(build_param_groups(model, b))
+        optim_kwargs = get_optim_kwargs(base_optimizer)
         optimizer = QuantOptimizer(
-            base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
+            base_optimizer,
+            quantizer,
+            ProxHardQuant(),
+            quant_per_channel=True,
+            **optim_kwargs,
         )
         optimizer.zero_grad()
         optimizer.step()
@@ -531,8 +540,13 @@ def test_int8_dynamic_activation_intx_e2e(
 
         # quantize weights with PARQ
         base_optimizer = torch.optim.SGD(build_param_groups(model, b, group_size))
+        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
-            base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
+            base_optimizer,
+            quantizer,
+            ProxHardQuant(),
+            quant_per_channel=True,
+            **optim_kwargs,
         )
         optimizer.zero_grad()
 
From af5fb9720abf5059ea1c9fed2e9bc9f3561f44cb Mon Sep 17 00:00:00 2001
From: lisjin
Date: Thu, 2 Oct 2025 07:58:12 -0700
Subject: [PATCH 3/4] Update torchao_convert to match notebook

---
 torchao/prototype/parq/optim/quantopt.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/torchao/prototype/parq/optim/quantopt.py b/torchao/prototype/parq/optim/quantopt.py
index f78133eb03..93b3dc7774 100644
--- a/torchao/prototype/parq/optim/quantopt.py
+++ b/torchao/prototype/parq/optim/quantopt.py
@@ -163,17 +163,24 @@ def _get_quantizer(self, group_idx: int) -> Optional[Quantizer]:
     def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
         """Converts model parameters to torchao quantized tensor subclasses."""
         model.eval()
-        self.restore_latent_params()
 
         # TODO(lvj): find more robust way to identify embedding layers
         embed_data_ptrs = set()
         linear_data_ptrs = set()
+        embed_modules = []
         for module in model.modules():
             if isinstance(module, nn.Embedding):
+                embed_modules.append(module)
                 embed_data_ptrs.add(module.weight.data_ptr())
             elif _is_linear(module) and module.weight.data_ptr() not in embed_data_ptrs:
                 linear_data_ptrs.add(module.weight.data_ptr())
 
+        tied_embeddings = getattr(model, "_tied_weights_keys", None) is not None
+        if tied_embeddings:
+            # Workaround for dynamic activations on tied embeddings
+            for module in embed_modules:
+                setattr(module, "bias", None)
+
         filter_fns = []
         configs = []
         attach_hf_config = _is_hf_model(model)
@@ -194,7 +201,7 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
             any_embed = any(p.data_ptr() in embed_data_ptrs for p in group["params"])
             config = _get_config_from_quantizer(
                 quantizer,
-                weight_only or any_embed,
+                weight_only or (any_embed and not tied_embeddings),
                 device,
                 group["quant_bits"],
                 group.get("quant_block_size"),
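The tied-embedding workaround above keys off the _tied_weights_keys attribute that Hugging Face transformers models set when an output projection shares its weight with the input embedding. A minimal sketch of just that check; the toy module and its attribute value are illustrative assumptions, not code from this series.

    import torch.nn as nn

    class TinyTiedLM(nn.Module):
        # Illustrative stand-in for a Hugging Face model with tied weights.
        _tied_weights_keys = ["lm_head.weight"]

        def __init__(self, vocab: int = 128, dim: int = 16):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab, dim)
            self.lm_head = nn.Linear(dim, vocab, bias=False)
            self.lm_head.weight = self.embed_tokens.weight  # tie the weights

    model = TinyTiedLM()
    # The same condition torchao_convert now uses to trigger the workaround.
    assert getattr(model, "_tied_weights_keys", None) is not None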
From abc60fdef8bb6ed64037427ac18ff503e2b000be Mon Sep 17 00:00:00 2001
From: lisjin
Date: Thu, 2 Oct 2025 11:15:35 -0700
Subject: [PATCH 4/4] Handle weight-only embeddings in torchao_convert

---
 test/prototype/test_parq.py              | 40 +++++++++++++++++-------
 torchao/prototype/parq/optim/quantopt.py | 13 ++++++--
 2 files changed, 38 insertions(+), 15 deletions(-)

diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py
index e9ccdb1e99..a91e119572 100644
--- a/test/prototype/test_parq.py
+++ b/test/prototype/test_parq.py
@@ -161,6 +161,7 @@ def build_param_groups(
     model,
     b: int = 2,
     group_size: Optional[int] = None,
+    embed_b: int = 4,
 ):
     params_quant, params_embed, params_no_quant = split_param_groups(model)
     quant_kwargs = {}
@@ -171,14 +172,27 @@ def build_param_groups(
         {"params": params_no_quant},
     ]
     if params_embed:
-        param_groups.append({"params": params_embed, "quant_bits": 4})
+        param_groups.append({"params": params_embed, "quant_bits": embed_b})
     return param_groups
 
 
-def get_optim_kwargs(base_optimizer, embedding=True, quant_cls=UnifTorchaoQuantizer):
+def get_optim_kwargs(
+    model, base_optimizer, embedding=True, quant_cls=UnifTorchaoQuantizer
+):
     optim_kwargs = {}
     if embedding:
-        group_idx = len(base_optimizer.param_groups) - 2
+        embed_data_ptrs = set(
+            (
+                m.weight.data_ptr()
+                for m in model.modules()
+                if isinstance(m, nn.Embedding)
+            )
+        )
+        group_idx = -1
+        for i, group in enumerate(base_optimizer.param_groups):
+            if all(p.data_ptr() in embed_data_ptrs for p in group["params"]):
+                group_idx = i
+                break
         assert group_idx > -1
         optim_kwargs["group_quantizer_map"] = {group_idx: quant_cls()}
     return optim_kwargs
@@ -221,7 +235,7 @@ def compare_parq_convert(
     orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model
 
     # equivalent to torchao's convert step
-    optimizer.torchao_convert(model, weight_only=weight_only)
+    optimizer.torchao_convert(model, weight_only=weight_only, embed_weight_only=True)
 
     inputs = model.example_inputs(device=_DEVICE)
     torch.testing.assert_close(model(inputs), orig_model(inputs))
@@ -289,13 +303,15 @@ def test_parq_train_loop(
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
-        param_groups = build_param_groups(model, b)
+        param_groups = build_param_groups(model, b, embed_b=b)
         base_optimizer = torch.optim.AdamW(param_groups)
 
         prox_map = (
             ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2)
         )
-        optim_kwargs = get_optim_kwargs(base_optimizer)
+        optim_kwargs = get_optim_kwargs(
+            model, base_optimizer, quant_cls=type(quantizer), embedding=False
+        )
         optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map, **optim_kwargs)
         for _ in range(3):
             x = model.example_inputs(device=_DEVICE)
@@ -365,7 +381,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
 
         b = 4
         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
-        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             Int4UnifTorchaoQuantizer(),
@@ -387,7 +403,7 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         quantize_(m_ref, config)
 
         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
-        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             UnifTorchaoQuantizer(),
@@ -464,7 +480,7 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         quantize_(m_ref, config, filter_fn=_is_linear)
 
         base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
-        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             quantizer,
@@ -486,7 +502,7 @@ def test_intx_weight_only_tied_embed_linear(
         quantizer = StretchedUnifTorchaoQuantizer(b)
 
         base_optimizer = torch.optim.SGD(build_param_groups(model, b))
-        optim_kwargs = get_optim_kwargs(base_optimizer)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer)
         optimizer = QuantOptimizer(
             base_optimizer,
             quantizer,
@@ -498,7 +514,7 @@ def test_intx_weight_only_tied_embed_linear(
         optimizer.step()
         apply_activation_quantization(model, optimizer, model_dtype)
 
-        optimizer.torchao_convert(model)
+        optimizer.torchao_convert(model, embed_weight_only=True)
         check_torchao_tensor_subclass(self, model)
         self.assertTrue(
             torch.equal(model.embed_tokens.weight.qdata, model.linear2.weight.qdata)
@@ -540,7 +556,7 @@ def test_int8_dynamic_activation_intx_e2e(
 
         # quantize weights with PARQ
         base_optimizer = torch.optim.SGD(build_param_groups(model, b, group_size))
-        optim_kwargs = get_optim_kwargs(base_optimizer, embedding=False)
+        optim_kwargs = get_optim_kwargs(model, base_optimizer, embedding=False)
         optimizer = QuantOptimizer(
             base_optimizer,
             quantizer,
diff --git a/torchao/prototype/parq/optim/quantopt.py b/torchao/prototype/parq/optim/quantopt.py
index 93b3dc7774..e6eb8289d2 100644
--- a/torchao/prototype/parq/optim/quantopt.py
+++ b/torchao/prototype/parq/optim/quantopt.py
@@ -160,9 +160,15 @@ def _get_quantizer(self, group_idx: int) -> Optional[Quantizer]:
             return self.group_quantizer_map[group_idx]
         return self.quantizer
 
-    def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
+    def torchao_convert(
+        self,
+        model: nn.Module,
+        weight_only: bool = False,
+        embed_weight_only: bool = False,
+    ) -> None:
         """Converts model parameters to torchao quantized tensor subclasses."""
         model.eval()
+        self.restore_latent_params()
 
         # TODO(lvj): find more robust way to identify embedding layers
         embed_data_ptrs = set()
@@ -175,9 +181,10 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
             elif _is_linear(module) and module.weight.data_ptr() not in embed_data_ptrs:
                 linear_data_ptrs.add(module.weight.data_ptr())
 
-        tied_embeddings = getattr(model, "_tied_weights_keys", None) is not None
-        if tied_embeddings:
+        tied_embeddings = False
+        if not embed_weight_only and getattr(model, "_tied_weights_keys", None):
             # Workaround for dynamic activations on tied embeddings
+            tied_embeddings = True
             for module in embed_modules:
                 setattr(module, "bias", None)
 
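Taken together, the weight_only and embed_weight_only keywords control which parameter groups stay weight-only at conversion time. A small sketch of the call surface implied by the diffs; the wrapper function and its name are illustrative, and the model/optimizer are assumed to be set up as in the sketch after patch 1.

    # Hypothetical helper; only the torchao_convert keywords come from the diffs.
    def convert_for_deployment(model, optimizer, weight_only=False):
        if weight_only:
            # Every regularized group gets a weight-only config.
            optimizer.torchao_convert(model, weight_only=True)
        else:
            # Linear groups may receive dynamic activation quantization, while
            # embed_weight_only=True keeps embedding groups weight-only and
            # skips the tied-embedding workaround introduced in patch 3.
            optimizer.torchao_convert(model, weight_only=False, embed_weight_only=True)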