Merged
101 changes: 80 additions & 21 deletions test/prototype/test_parq.py
@@ -54,9 +54,16 @@ def split_param_groups(model) -> tuple[list, list, list]:
params_quant, params_embed, params_no_quant = [], [], []

def get_param_groups(model):
seen_data_ptrs = set() # avoid duplicates in case of tied weights
for module in model.children():
is_linear = _is_linear(module)
for n, p in module.named_parameters():
if n == "weight":
data_ptr = p.data_ptr()
if data_ptr in seen_data_ptrs:
continue
seen_data_ptrs.add(data_ptr)

if is_linear and n == "weight":
params_quant.append(p)
elif isinstance(module, nn.Embedding) and n == "weight":
@@ -152,7 +159,12 @@ def compare_parq_convert(
def check_torchao_tensor_subclass(
test_case: common_utils.TestCase, model: nn.Module, weight_only: bool = False
):
for module in model.modules():
for name, module in model.named_modules():
if not hasattr(module, "weight") or f"{name}.weight" in getattr(
model, "_tied_weights_keys", []
):
continue

if not weight_only and _is_linear(module):
test_case.assertTrue(isinstance(module.weight, IntxUnpackedToInt8Tensor))
test_case.assertTrue(
@@ -163,34 +175,58 @@ def check_torchao_tensor_subclass(
test_case.assertTrue(module.weight.activation_quantization is None)


def apply_activation_quantization(
model: nn.Module, optimizer: torch.optim.Optimizer, model_dtype: torch.dtype
):
# apply torchao quantized activations on top
activation_config = IntxFakeQuantizeConfig(
torch.int8, "per_token", is_symmetric=False, scale_precision=model_dtype
)
qat_config = QATConfig(activation_config=activation_config, step="prepare")
for filter_fn in optimizer.get_filter_fns(model):
try:
quantize_(model, qat_config, filter_fn=filter_fn)
except ValueError as e:
# embeddings do not support activation fake quantization; re-raise anything unexpected
if str(e) != "Activation fake quantization is not supported for embedding":
raise


class M(nn.Module):
def __init__(self, m=256, n=128, k=16, bias=False, embedding=True):
_tied_weights_keys: list[str] = []

def __init__(
self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
):
super().__init__()
self.embedding = nn.Embedding(10, m) if embedding else nn.Identity()
self.embedding = nn.Embedding(k, m) if embedding else nn.Identity()
self.linear1 = nn.Linear(m, n, bias=bias)
self.linear2 = nn.Linear(n, k, bias=bias)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()

if embedding and tied_weights:
assert self.embedding.weight.shape == self.linear2.weight.shape
self.linear2.weight = self.embedding.weight
self._tied_weights_keys.append("linear2.weight")

def reset_parameters(self):
for module in (self.linear1, self.linear2):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)

def example_inputs(self, device=None):
return (
torch.randint(1, 10, (1, self.linear1.in_features), device=device)
if isinstance(self.embedding, nn.Embedding)
else torch.randn(1, self.linear1.in_features, device=device)
)
if isinstance(self.embedding, nn.Identity):
inputs = torch.randn(1, self.linear1.in_features, device=device)
else:
k = self.embedding.num_embeddings
inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
return inputs

def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
x = self.sigmoid(x)
x = self.relu(self.linear1(x))
x = self.sigmoid(self.linear2(x))
return x


@@ -297,7 +333,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
ProxHardQuant(),
quant_per_channel=True,
)
compare_parq_convert(model, m_ref, optimizer)
compare_parq_convert(model, m_ref, optimizer, weight_only=True)

@unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
@common_utils.parametrize("b", [2, 3, 4, 8])
@@ -399,6 +435,30 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
compare_parq_convert(model, m_ref, optimizer, weight_only=True)
check_torchao_tensor_subclass(self, model, weight_only=True)

@common_utils.parametrize("b", [2, 3])
@common_utils.parametrize(
"model_dtype", [torch.float16, torch.float32, torch.bfloat16]
)
def test_intx_weight_only_tied_embed_linear(
self, b: int = 2, model_dtype: torch.dtype = torch.float32
):
model = M(m=256, n=256, tied_weights=True).to(_DEVICE)

quantizer = StretchedUnifTorchaoQuantizer(b)
base_optimizer = torch.optim.SGD(build_param_groups(model, b))
optimizer = QuantOptimizer(
base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
)
optimizer.zero_grad()
optimizer.step()

apply_activation_quantization(model, optimizer, model_dtype)
optimizer.torchao_convert(model)
check_torchao_tensor_subclass(self, model)
self.assertTrue(
torch.equal(model.embedding.weight.qdata, model.linear2.weight.qdata)
)


class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
def setUp(self):
@@ -435,16 +495,12 @@ def test_int8_dynamic_activation_intx_e2e(
optimizer = QuantOptimizer(
base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
)

optimizer.zero_grad()
optimizer.step()

# apply torchao quantized activations on top
activation_config = IntxFakeQuantizeConfig(
torch.int8, "per_token", is_symmetric=False, scale_precision=model_dtype
)
qat_config = QATConfig(activation_config=activation_config, step="prepare")
for filter_fn in optimizer.get_filter_fns(model):
quantize_(model, qat_config, filter_fn=filter_fn)
apply_activation_quantization(model, optimizer, model_dtype)

out = model(x)
torch.testing.assert_close(out, ref_out, atol=0, rtol=0)

@@ -462,7 +518,10 @@ def test_int8_dynamic_activation_intx_e2e(
check_torchao_tensor_subclass(self, model)

if attach_hf_config:
reg_param_names = {n for n, m in model.named_modules() if _is_linear(m)}
reg_param_names = {
n for n, m in model.named_modules() if isinstance(m, nn.Embedding)
}
reg_param_names.add("_default")
module_fqn_to_config = (
model.config.quantization_config.quant_type.module_fqn_to_config
)
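
Note on the tied-weights setup in this test file: it relies on ordinary PyTorch parameter sharing, where assigning one module's weight Parameter to another makes both attributes reference the same tensor, so their data_ptr() values coincide and the HF-style _tied_weights_keys list can record the duplicate. A minimal standalone sketch of that mechanism (not part of this PR; layer sizes are arbitrary):

import torch.nn as nn

# Tie an embedding and a linear head the way M(tied_weights=True) does.
embedding = nn.Embedding(16, 64)        # weight shape (16, 64)
linear = nn.Linear(64, 16, bias=False)  # weight shape (16, 64)
assert embedding.weight.shape == linear.weight.shape

linear.weight = embedding.weight        # both attributes now hold the same Parameter
tied_keys = ["linear.weight"]           # mirrors the _tied_weights_keys convention

# Shared storage is what the data_ptr()-based de-duplication in the tests detects.
assert linear.weight.data_ptr() == embedding.weight.data_ptr()

The qdata equality assertion in test_intx_weight_only_tied_embed_linear then checks that the two tied weights still agree after conversion.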
1 change: 1 addition & 0 deletions torchao/core/config.py
@@ -196,6 +196,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
"torchao.prototype.parq",
"torchao.dtypes",
"torchao.prototype.awq",
"torchao.prototype.parq.quant",
"torchao.quantization.quantize_.common",
"torchao.quantization.quantize_.workflows",
}
36 changes: 27 additions & 9 deletions torchao/prototype/parq/optim/quantopt.py
@@ -14,6 +14,7 @@
from torch.optim import Optimizer

from torchao.quantization import quantize_
from torchao.quantization.quant_api import _is_linear

from ..quant import Quantizer, UnifTorchaoQuantizer
from ..quant.config_torchao import (
@@ -158,24 +159,30 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
self.restore_latent_params()

# TODO(lvj): find more robust way to identify embedding layers
embed_data_ptrs = {
module.weight.data_ptr()
for module in model.modules()
if isinstance(module, nn.Embedding)
}
embed_data_ptrs = set()
linear_data_ptrs = set()
for module in model.modules():
if isinstance(module, nn.Embedding):
embed_data_ptrs.add(module.weight.data_ptr())
elif _is_linear(module) and module.weight.data_ptr() not in embed_data_ptrs:
linear_data_ptrs.add(module.weight.data_ptr())

filter_fns = []
configs = []
attach_hf_config = _is_hf_model(model)
for group, filter_fn in zip(
self.regularized_param_groups(), self.get_filter_fns(model)
all_linear_layers_idx = -1
for i, (group, filter_fn) in enumerate(
zip(self.regularized_param_groups(), self.get_filter_fns(model))
):
filter_fns.append(filter_fn)
quantizer = group.get("quantizer", self.quantizer)
if not isinstance(quantizer, UnifTorchaoQuantizer) or not group["params"]:
configs.append(None)
continue

if set((p.data_ptr() for p in group["params"])) == linear_data_ptrs:
all_linear_layers_idx = i

device = group["params"][0].device
any_embed = any(p.data_ptr() in embed_data_ptrs for p in group["params"])
config = _get_config_from_quantizer(
@@ -187,10 +194,21 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
)
configs.append(config)

filter_fns_orig = filter_fns[:]
configs_orig = configs[:]

# If one group has all the linear layers, then set its config as default
if all_linear_layers_idx > -1:
module_to_config = {"_default": configs[all_linear_layers_idx]}
del filter_fns[all_linear_layers_idx]
del configs[all_linear_layers_idx]
else:
module_to_config = None

if attach_hf_config:
_attach_hf_quantization_config(model, filter_fns, configs)
_attach_hf_quantization_config(model, filter_fns, configs, module_to_config)

for config, filter_fn in zip(configs, filter_fns):
for config, filter_fn in zip(configs_orig, filter_fns_orig):
quantize_(model, config, filter_fn=filter_fn)

@torch._disable_dynamo
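
Note on the _default selection in torchao_convert above: a parameter group is treated as covering all linear layers when the set of its parameters' storage pointers equals the set collected from the model's linear weights, and only then is its config promoted to the "_default" entry passed to _attach_hf_quantization_config. A rough standalone illustration of that matching idea (using isinstance(nn.Linear) in place of torchao's _is_linear, with made-up layer sizes):

import torch.nn as nn

model = nn.Sequential(nn.Embedding(16, 32), nn.Linear(32, 32), nn.Linear(32, 16))

# Storage pointers of every linear weight in the model.
linear_data_ptrs = {
    m.weight.data_ptr() for m in model.modules() if isinstance(m, nn.Linear)
}

# A group containing exactly the linear weights matches the set above and would
# therefore supply the "_default" config; other groups (e.g. embeddings) keep
# their own per-module entries.
group_params = [model[1].weight, model[2].weight]
print({p.data_ptr() for p in group_params} == linear_data_ptrs)  # True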