Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,56 @@ def reset_memory():
assert param.is_cuda
self.assertLess(memory_streaming, memory_baseline)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_fqn_config_quantized_nested_module(self):
class NestedModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(16, 16)

class TopLevelModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.nested = NestedModule()
self.linear1 = torch.nn.Linear(16, 16)

m = TopLevelModule()
quant_config = FqnToConfig(
{
"nested.linear": Int8WeightOnlyConfig(),
"linear1": Int8WeightOnlyConfig(),
Comment on lines +1141 to +1142
Copy link
Contributor

@jerryzh168 jerryzh168 Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use new tensors, e.g. Int4WeightOnlyConfig (with Int4Tensor), since we are moving away from AQT

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment is missed I think @jcaip

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh sorry didn't see this - let me update all of the configs in the tests

}
)
quantize_(m, quant_config, filter_fn=None)

assert isinstance(m.nested.linear.weight, AffineQuantizedTensor)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_fqn_config_quantized_nested_module_param(self):
class NestedModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(16, 16)

class TopLevelModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.nested = NestedModule()
self.linear1 = torch.nn.Linear(16, 16)

m = TopLevelModule()
quant_config = FqnToConfig(
{
"nested.linear.weight": Int8WeightOnlyConfig(),
"linear1.weight": Int8WeightOnlyConfig(),
}
)
quantize_(m, quant_config, filter_fn=None)

assert isinstance(m.nested.linear.weight, AffineQuantizedTensor)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)


if __name__ == "__main__":
unittest.main()
5 changes: 1 addition & 4 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,11 +484,8 @@ def quantize_(
or _module_param_matches_fqn_config(module, module_fqn, config)
or ("_default" in config.fqn_to_config and _is_linear(module))
):
module_name = (
module_fqn.rsplit(".", 1) if "." in module_fqn else module_fqn
)
# this replaces inplace, so no need to reassign
_fqn_to_config_handler(module, module_name, config)
_fqn_to_config_handler(module, module_fqn, config)
if device is not None:
module.to(device=device)
return
Expand Down
Loading