From c2daf5665cb42daf2687745a061e2535d4ec4f1a Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 21 Nov 2025 07:41:56 -0800 Subject: [PATCH] Add PerBlock to safe globals **Summary:** Add PerBlock to safe globals so users don't have to do this themselves when they load config.json with PerBlock. ``` WeightsUnpickler error: Unsupported global: GLOBAL torchao.quantization.granularity.PerBlock was not an allowed global by default. Please use `torch.serialization.add_safe_globals([torchao.quantization.granularity.PerBlock])` or the `torch.serialization.safe_globals([torchao.quantization.granularity.PerBlock])` context manager to allowlist this global if you trust this class/function. ``` **Test Plan:** ``` python test/core/test_config.py -k test_granularity_serialization ``` --- test/core/test_config.py | 47 ++++++++++++++++++- .../test_affine_quantized_tensor_parallel.py | 3 +- torchao/quantization/granularity.py | 5 ++ torchao/quantization/observer.py | 10 +--- torchao/quantization/quant_api.py | 2 +- 5 files changed, 54 insertions(+), 13 deletions(-) diff --git a/test/core/test_config.py b/test/core/test_config.py index 3fb9d435fa..2957b1d1ff 100644 --- a/test/core/test_config.py +++ b/test/core/test_config.py @@ -6,6 +6,7 @@ import json import os +import subprocess import tempfile import warnings from dataclasses import dataclass @@ -23,7 +24,11 @@ AWQConfig, AWQStep, ) -from torchao.quantization import PerBlock +from torchao.quantization import ( + PerBlock, + PerRow, + PerTensor, +) from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, Float8DynamicActivationInt4WeightConfig, @@ -36,10 +41,11 @@ Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig, ModuleFqnToConfig, - PerRow, UIntXWeightOnlyConfig, + quantize_, ) from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig +from torchao.utils import is_sm_at_least_89 # Define test configurations as fixtures configs = [ @@ -155,6 +161,43 @@ def test_reconstructable_dict_file_round_trip(config): os.unlink(temp_file_path) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not is_sm_at_least_89(), reason="needs CUDA capability 8.9+") +@pytest.mark.parametrize( + "granularity", + [ + PerTensor(), + PerRow(), + (PerBlock([1, 128]), PerBlock([128, 128])), + ], +) +def test_granularity_serialization(granularity): + """ + Ensure that only `import torchao` is needed to load granularities used + in `Float8DynamicActivationFloat8WeightConfig`. + """ + + m = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda") + fname = None + with tempfile.NamedTemporaryFile(delete=False, mode="w") as f: + config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + quantize_(m, config=config) + torch.save(m.state_dict(), f.name) + fname = f.name + + assert fname is not None + + code = f""" +import torch +import torchao +_ = torch.load('{fname}', weights_only=True) + """ + + subprocess_out = subprocess.run(["python"], input=code, text=True) + os.remove(fname) + assert subprocess_out.returncode == 0, "failed weights-only load" + + # Define a dummy config in a non-allowed module @dataclass class DummyNonAllowedConfig(AOBaseConfig): diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 5a6e89c25f..2f234c16f3 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -21,8 +21,9 @@ Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig, + PerRow, + PerTensor, ) -from torchao.quantization.observer import PerRow, PerTensor from torchao.quantization.quant_api import quantize_ if common_utils.SEED is None: diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py index 97d9c07b6f..f871649f48 100644 --- a/torchao/quantization/granularity.py +++ b/torchao/quantization/granularity.py @@ -6,6 +6,8 @@ from dataclasses import dataclass +import torch + @dataclass(frozen=True) class Granularity: @@ -138,3 +140,6 @@ class PerBlock(Granularity): # list. Example error: # https://gist.github.com/vkuzo/ab4d6aec83cb98ad9417898d2c024a2c block_size: tuple[int, ...] + + +torch.serialization.add_safe_globals([PerBlock, PerRow, PerTensor]) diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index d12ffaf520..0e18770ae5 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -12,11 +12,7 @@ from torchao.quantization.quant_primitives import _fake_quantize_affine -from .granularity import ( - Granularity, - PerRow, - PerTensor, -) +from .granularity import Granularity from .quant_primitives import ( MappingType, ZeroPointDomain, @@ -350,7 +346,3 @@ def calculate_qparams(self): self.preserve_zero, self.zero_point_domain, ) - - -# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` -torch.serialization.add_safe_globals([PerRow, PerTensor]) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 1e176f9e9b..6c7ca9658f 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -332,10 +332,10 @@ def insert_observers_( ``` import torch import torch.nn as nn + from torchao.quantization import PerTensor from torchao.quantization.linear_observer_tensor import insert_observers_ from torchao.quantization.observer import ( AffineQuantizedMinMaxObserver, - PerTensor, MappingType )