diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 506cec9dea..d608d1f357 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -853,6 +853,54 @@ def test_config_deprecation(self):
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Checkpoints are produced in SM90+")
 class TestFqnToConfig(TestCase):
+    def test_fqn_to_config_repr_custom(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_parameter(
+                    "x", torch.nn.Parameter(torch.randn(128, 128, dtype=torch.bfloat16))
+                )
+                self.register_parameter(
+                    "y", torch.nn.Parameter(torch.randn(128, 128, dtype=torch.bfloat16))
+                )
+
+        custom_module = TestModule().cuda().eval()
+        custom_module_config = FqnToConfig(
+            {
+                "x": Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(),
+                ),
+            }
+        )
+        quantize_(
+            custom_module,
+            custom_module_config,
+            filter_fn=None,
+        )
+        assert str(custom_module).startswith("TestModule(x=Float8Tensor(")
+        assert str(custom_module.x) in str(custom_module)
+
+    def test_fqn_to_config_repr_linear(self):
+        linear_model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
+        linear_quant_config = FqnToConfig(
+            {
+                "linear1.weight": Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(),
+                ),
+            }
+        )
+        quantize_(
+            linear_model,
+            linear_quant_config,
+            filter_fn=None,
+        )
+        expected_starting_str = (
+            "Linear(in_features=64, out_features=32, bias=False, weight=Float8Tensor("
+        )
+
+        assert str(linear_model).startswith(expected_starting_str)
+        assert str(linear_model.linear1.weight) in str(linear_model)
+
     def test_quantize_param_fqn_exact(self):
         from transformers import AutoConfig
         from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 09c2edcd9f..73b88fd215 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -21,6 +21,7 @@
 import warnings
 from collections import OrderedDict
 from dataclasses import dataclass, field
+from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union
 from typing import OrderedDict as OrderedDictType
@@ -414,6 +415,19 @@ def _embedding_extra_repr(self):
     return f"num_embeddings={self.weight.shape[0]}, embedding_dim={self.weight.shape[1]}, weight={_quantization_type(self.weight)}"
 
 
+def _module_extra_repr(self, original_extra_repr, parameter_name):
+    module_torchao_extra_repr = []
+
+    original_extra_repr_str = original_extra_repr()
+    if len(original_extra_repr_str) > 0:
+        module_torchao_extra_repr.append(original_extra_repr_str)
+
+    module_torchao_extra_repr.append(
+        f"{parameter_name}={_quantization_type(getattr(self, parameter_name))}"
+    )
+    return ", ".join(module_torchao_extra_repr)
+
+
 def _get_linear_subclass_inserter(
     constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs
 ):
@@ -1373,11 +1387,22 @@ def _int8_weight_only_transform(
             "applying int8 weight only quant requires module to have {parameter_name} attribute"
             + " but {module} does not have one"
         )
-    new_weight = _int8_weight_only_quantize_tensor(
+    quantized_tensor = _int8_weight_only_quantize_tensor(
         getattr(module, parameter_name), config
     )
-    setattr(module, parameter_name, torch.nn.Parameter(new_weight, requires_grad=False))
-    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    setattr(
+        module,
+        parameter_name,
+        torch.nn.Parameter(quantized_tensor, requires_grad=False),
+    )
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
+    )
     return module
@@ -1662,16 +1687,23 @@ def _float8_weight_only_transform(
     if isinstance(module, Float8Linear):
         module = _unwrap_float8_linear(module)
 
-    new_weight = _float8_weight_only_quant_tensor(
+    quantized_tensor = _float8_weight_only_quant_tensor(
         getattr(module, parameter_name), config
     )
     setattr(
         module,
         parameter_name,
-        torch.nn.Parameter(new_weight, requires_grad=False),
+        torch.nn.Parameter(quantized_tensor, requires_grad=False),
+    )
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
     )
-    module.extra_repr = types.MethodType(_linear_extra_repr, module)
     return module
@@ -1918,7 +1950,14 @@ def _float8_dynamic_activation_float8_weight_transform(
         parameter_name,
         torch.nn.Parameter(quantized_tensor, requires_grad=False),
     )
-    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
+    )
     return module
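
For reference, the repr-composition pattern the patch applies in each transform can be exercised standalone. The sketch below is not part of the patch: it mirrors the new _module_extra_repr helper on a plain nn.Linear, with _describe as a hypothetical stand-in for torchao's _quantization_type, and the printed output shown only as an indication of the expected shape of the repr.

# Standalone sketch of the extra_repr patching pattern used above:
# functools.partial pre-binds the module's original extra_repr and the target
# parameter name, and types.MethodType rebinds the result as the instance's
# extra_repr. `_describe` is a hypothetical stand-in for _quantization_type.
import types
from functools import partial

import torch


def _describe(param):
    # Placeholder for torchao's _quantization_type(...); reports shape/dtype only.
    return f"Tensor(shape={tuple(param.shape)}, dtype={param.dtype})"


def _module_extra_repr(self, original_extra_repr, parameter_name):
    parts = []
    original_extra_repr_str = original_extra_repr()
    if len(original_extra_repr_str) > 0:
        parts.append(original_extra_repr_str)
    parts.append(f"{parameter_name}={_describe(getattr(self, parameter_name))}")
    return ", ".join(parts)


linear = torch.nn.Linear(64, 32, bias=False)
linear.extra_repr = types.MethodType(
    partial(
        _module_extra_repr,
        original_extra_repr=linear.extra_repr,
        parameter_name="weight",
    ),
    linear,
)
# Prints roughly:
# Linear(in_features=64, out_features=32, bias=False, weight=Tensor(shape=(32, 64), dtype=torch.float32))
print(linear)

Binding through partial keeps the module's original extra_repr output (e.g. in_features/out_features/bias for Linear, or an empty string for a bare custom module) and appends the quantized parameter's description, which is what the two new tests assert against.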