Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions test/core/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import json
import os
import subprocess
import tempfile
import warnings
from dataclasses import dataclass
Expand All @@ -23,7 +24,11 @@
AWQConfig,
AWQStep,
)
from torchao.quantization import PerBlock
from torchao.quantization import (
PerBlock,
PerRow,
PerTensor,
)
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Expand All @@ -36,10 +41,11 @@
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
ModuleFqnToConfig,
PerRow,
UIntXWeightOnlyConfig,
quantize_,
)
from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig
from torchao.utils import is_sm_at_least_89

# Define test configurations as fixtures
configs = [
Expand Down Expand Up @@ -155,6 +161,43 @@ def test_reconstructable_dict_file_round_trip(config):
os.unlink(temp_file_path)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not is_sm_at_least_89(), reason="needs CUDA capability 8.9+")
@pytest.mark.parametrize(
    "granularity",
    [
        PerTensor(),
        PerRow(),
        (PerBlock([1, 128]), PerBlock([128, 128])),
    ],
)
def test_granularity_serialization(granularity):
    """
    Ensure that only `import torchao` is needed to load granularities used
    in `Float8DynamicActivationFloat8WeightConfig`.

    A linear layer is quantized with the given granularity, its state dict is
    saved to a temporary file, and that file is then loaded with
    `weights_only=True` from a separate Python process.
    """
    # Function-scope import so this test does not depend on the module's
    # import block also providing `sys`.
    import sys

    m = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")

    # Reserve a path and close the handle immediately: `torch.save` reopens the
    # file by name, which is not possible on every platform while the
    # temporary file object is still open.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        fname = f.name

    try:
        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
        quantize_(m, config=config)
        torch.save(m.state_dict(), fname)

        # Why a subprocess instead of calling `torch.load` directly: by the
        # time this test runs, the current process has already imported
        # everything. A fresh interpreter shows that `import torchao` alone is
        # enough for loading to work. Pattern copied from
        # https://github.com/pytorch/ao/blob/main/test/prototype/mx_formats/test_mx_serialization.py#L36
        #
        # `{fname!r}` emits a correctly escaped string literal, so paths
        # containing backslashes or quotes do not corrupt the script.
        code = (
            "import torch\n"
            "import torchao\n"
            f"_ = torch.load({fname!r}, weights_only=True)\n"
        )
        # `sys.executable` guarantees the child uses the same interpreter (and
        # therefore the same installed torch/torchao) as the test run; a bare
        # "python" may resolve to a different one or not exist on PATH.
        subprocess_out = subprocess.run(
            [sys.executable], input=code, text=True, capture_output=True
        )
    finally:
        # `delete=False` means nobody else removes the file; clean up even if
        # quantization, saving, or the subprocess launch raised.
        os.remove(fname)

    assert subprocess_out.returncode == 0, (
        f"failed weights-only load:\n{subprocess_out.stderr}"
    )


# Define a dummy config in a non-allowed module
@dataclass
class DummyNonAllowedConfig(AOBaseConfig):
Expand Down
3 changes: 2 additions & 1 deletion test/dtypes/test_affine_quantized_tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
PerRow,
PerTensor,
)
from torchao.quantization.observer import PerRow, PerTensor
from torchao.quantization.quant_api import quantize_

if common_utils.SEED is None:
Expand Down
5 changes: 5 additions & 0 deletions torchao/quantization/granularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class Granularity:
Expand Down Expand Up @@ -138,3 +140,6 @@ class PerBlock(Granularity):
# list. Example error:
# https://gist.github.com/vkuzo/ab4d6aec83cb98ad9417898d2c024a2c
block_size: tuple[int, ...]


# Register the granularity classes as safe globals at import time so that
# checkpoints referencing them can be loaded with
# `torch.load(..., weights_only=True)`.
torch.serialization.add_safe_globals([PerBlock, PerRow, PerTensor])
10 changes: 1 addition & 9 deletions torchao/quantization/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@

from torchao.quantization.quant_primitives import _fake_quantize_affine

from .granularity import (
Granularity,
PerRow,
PerTensor,
)
from .granularity import Granularity
from .quant_primitives import (
MappingType,
ZeroPointDomain,
Expand Down Expand Up @@ -350,7 +346,3 @@ def calculate_qparams(self):
self.preserve_zero,
self.zero_point_domain,
)


# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([PerRow, PerTensor])
2 changes: 1 addition & 1 deletion torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,10 @@ def insert_observers_(
```
import torch
import torch.nn as nn
from torchao.quantization import PerTensor
from torchao.quantization.linear_observer_tensor import insert_observers_
from torchao.quantization.observer import (
AffineQuantizedMinMaxObserver,
PerTensor,
MappingType
)

Expand Down
Loading