
Commit fd31a98

Add PerBlock to safe globals
**Summary:** Add `PerBlock` to torch's safe globals so users don't have to allowlist it themselves when they load a `config.json` that uses `PerBlock`. Without this, loading with `weights_only=True` fails with:

```
WeightsUnpickler error: Unsupported global: GLOBAL torchao.quantization.granularity.PerBlock was not an allowed global by default. Please use `torch.serialization.add_safe_globals([torchao.quantization.granularity.PerBlock])` or the `torch.serialization.safe_globals([torchao.quantization.granularity.PerBlock])` context manager to allowlist this global if you trust this class/function.
```

**Test Plan:**

```
python test/core/test_config.py -k test_granularity_serialization
```
1 parent 2ff1eb2 commit fd31a98
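
The user-facing effect, as a minimal sketch (the checkpoint path below is hypothetical): once `import torchao` runs, `PerBlock`, `PerRow`, and `PerTensor` are already registered with `torch.serialization`, so a `weights_only=True` load of a quantized checkpoint no longer needs a manual allowlist.

```python
# Minimal sketch of the flow this commit enables; "quantized_model.pt" is a hypothetical path.
import torch
import torchao  # importing torchao now registers PerBlock/PerRow/PerTensor via add_safe_globals

# Before this change, a manual allowlist was required:
# torch.serialization.add_safe_globals([torchao.quantization.granularity.PerBlock])

state_dict = torch.load("quantized_model.pt", weights_only=True)
```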

File tree

5 files changed, +50 -13 lines changed


test/core/test_config.py

Lines changed: 42 additions & 2 deletions
```diff
@@ -6,6 +6,7 @@
 
 import json
 import os
+import subprocess
 import tempfile
 import warnings
 from dataclasses import dataclass
@@ -23,7 +24,11 @@
     AWQConfig,
     AWQStep,
 )
-from torchao.quantization import PerBlock
+from torchao.quantization import (
+    PerBlock,
+    PerRow,
+    PerTensor,
+)
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8DynamicActivationInt4WeightConfig,
@@ -36,8 +41,8 @@
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     ModuleFqnToConfig,
-    PerRow,
     UIntXWeightOnlyConfig,
+    quantize_,
 )
 from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig
 
@@ -155,6 +160,41 @@ def test_reconstructable_dict_file_round_trip(config):
         os.unlink(temp_file_path)
 
 
+@pytest.mark.parametrize(
+    "granularity",
+    [
+        PerTensor(),
+        PerRow(),
+        (PerBlock([1, 128]), PerBlock([128, 128])),
+    ],
+)
+def test_granularity_serialization(granularity):
+    """
+    Ensure that only `import torchao` is needed to load granularities used
+    in `Float8DynamicActivationFloat8WeightConfig`.
+    """
+
+    m = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
+    fname = None
+    with tempfile.NamedTemporaryFile(delete=False, mode="w") as f:
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        quantize_(m, config=config)
+        torch.save(m.state_dict(), f.name)
+        fname = f.name
+
+    assert fname is not None
+
+    code = f"""
+import torch
+import torchao
+_ = torch.load('{fname}', weights_only=True)
+"""
+
+    subprocess_out = subprocess.run(["python"], input=code, text=True)
+    os.remove(fname)
+    assert subprocess_out.returncode == 0, "failed weights-only load"
+
+
 # Define a dummy config in a non-allowed module
 @dataclass
 class DummyNonAllowedConfig(AOBaseConfig):
```

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -22,7 +22,7 @@
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
 )
-from torchao.quantization.observer import PerRow, PerTensor
+from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.quant_api import quantize_
 
 if common_utils.SEED is None:
```

torchao/quantization/granularity.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -6,6 +6,8 @@
 
 from dataclasses import dataclass
 
+import torch
+
 
 @dataclass(frozen=True)
 class Granularity:
@@ -138,3 +140,6 @@ class PerBlock(Granularity):
     # list. Example error:
     # https://gist.github.com/vkuzo/ab4d6aec83cb98ad9417898d2c024a2c
     block_size: tuple[int, ...]
+
+
+torch.serialization.add_safe_globals([PerBlock, PerRow, PerTensor])
```

torchao/quantization/observer.py

Lines changed: 1 addition & 9 deletions
```diff
@@ -12,11 +12,7 @@
 
 from torchao.quantization.quant_primitives import _fake_quantize_affine
 
-from .granularity import (
-    Granularity,
-    PerRow,
-    PerTensor,
-)
+from .granularity import Granularity
 from .quant_primitives import (
     MappingType,
     ZeroPointDomain,
@@ -350,7 +346,3 @@ def calculate_qparams(self):
             self.preserve_zero,
             self.zero_point_domain,
         )
-
-
-# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`
-torch.serialization.add_safe_globals([PerRow, PerTensor])
```

torchao/quantization/quant_api.py

Lines changed: 1 addition & 1 deletion
````diff
@@ -332,10 +332,10 @@ def insert_observers_(
        ```
        import torch
        import torch.nn as nn
+       from torchao.quantization.granularity import PerTensor
        from torchao.quantization.linear_observer_tensor import insert_observers_
        from torchao.quantization.observer import (
            AffineQuantizedMinMaxObserver,
-           PerTensor,
            MappingType
        )
````
