diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index f99cf4a1b4..4a11a4c6e7 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -42,12 +42,14 @@
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
+    Int8StaticActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     _replace_with_custom_fn_if_matches_filter,
     quantize_,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    choose_qparams_affine,
     dequantize_affine,
 )
 from torchao.quantization.smoothquant import (
@@ -1123,6 +1125,27 @@ def test_dynamic_quant(self):
         # self.assertTrue(isinstance(m[0], DynamicallyPerAxisQuantizedLinear))
 
 
+class TestStaticQuant(unittest.TestCase):
+    def test_static_quant(self):
+        M, K, N = 8, 16, 8
+        x = torch.randn(M, K)
+        m = nn.Sequential(nn.Linear(K, N))
+        block_size = [M, K]  # per-tensor quantization
+        scale, _ = choose_qparams_affine(
+            x,
+            mapping_type=MappingType.SYMMETRIC,
+            block_size=block_size,
+            target_dtype=torch.int8,
+        )
+
+        y_ref = m(x)
+        quantize_(m, Int8StaticActivationInt8WeightConfig(scale))
+        y_test = m(x)
+
+        sqnr = compute_error(y_ref, y_test)
+        self.assertGreater(sqnr, 40.0)
+
+
 class TestWeightOnlyInt8Quant(unittest.TestCase):
     def test_weight_only_quant(self):
         for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py
index 581f75b925..3e017d32fc 100644
--- a/test/prototype/test_smoothquant.py
+++ b/test/prototype/test_smoothquant.py
@@ -15,8 +15,15 @@
 )
 from torchao.prototype.smoothquant.core import SmoothQuantStep
 from torchao.quantization import quantize_
+from torchao.quantization.linear_activation_scale import (
+    WeightTensorWithLinearActivationScaleMetadata,
+)
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt8WeightConfig,
+    Int8StaticActivationInt8WeightConfig,
+)
+from torchao.quantization.utils import (
+    compute_error as SQNR,
 )
 
 
@@ -34,16 +41,19 @@ def example_inputs(
         dtype=torch.bfloat16,
         device="cuda",
     ):
-        return [
-            torch.randn(
-                1,
-                sequence_length,
-                self.linear1.in_features,
-                dtype=dtype,
-                device=device,
-            )
-            for j in range(batch_size)
-        ]
+        # For SmoothQuant tests, we intentionally insert outliers into the input features
+        x = torch.randn(
+            batch_size,
+            sequence_length,
+            self.linear1.in_features,
+            dtype=dtype,
+            device=device,
+        )
+        n_outliers = max(1, int(x.size(-1) * 0.1))
+        # Randomly select outlier features
+        outlier_indices = torch.randperm(x.size(-1))[:n_outliers]
+        x[:, :, outlier_indices] *= 10.0
+        return (x,)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -52,7 +62,9 @@ def forward(self, x):
         return x
 
 
-@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+device_list = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+
+
 @unittest.skipIf(torch.version.hip is not None, "Skipping tests in ROCm")
 class TestSmoothQuant(unittest.TestCase):
     """SmoothQuant tests using only supported quantization configs."""
@@ -72,37 +84,27 @@ def setUpClass(cls):
             # TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
         ],
     )
-    @common_utils.parametrize("device", ["cpu", "cuda"])
+    @common_utils.parametrize("device", device_list)
     @common_utils.parametrize("input_dtype", [torch.bfloat16])
-    def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
+    def test_smoothquant_dynamic_act_accuracy(
+        self, alpha, base_config, device, input_dtype
+    ):
         """Test if SmoothQuant achieves
         lower loss than basic quantization."""
-        in_features = 64
-        out_features = 128
-
-        # Note: This is sanity check. For real run, consider Transformer model to reproduce.
-        X = torch.randn(16, in_features, dtype=input_dtype, device=device)
-        W = torch.randn(out_features, in_features, dtype=input_dtype, device=device)
-        # Create linear layer
-        linear = (
-            torch.nn.Linear(in_features, out_features, bias=False)
-            .to(device)
-            .to(input_dtype)
-        )
-        with torch.no_grad():
-            linear.weight.copy_(W)
+        m = ToyLinearModel().eval().to(device).to(input_dtype)
+        x = m.example_inputs(batch_size=16, dtype=input_dtype, device=device)
 
         # Reference output
-        out_ref = linear(X)
+        out_ref = m(*x)
 
         # Step 1. Basic quantization
-        basic_model = deepcopy(linear)
+        basic_model = deepcopy(m)
         quantize_(basic_model, base_config)
-        out_basic = basic_model(X)
+        out_basic = basic_model(*x)
         loss_base = torch.nn.functional.mse_loss(out_basic, out_ref).item()
 
-        # SmoothQuant quantization
-        model = deepcopy(linear)
+        # Step 2. SmoothQuant
+        model = deepcopy(m)
         config = SmoothQuantConfig(
             base_config=base_config,
             step=SmoothQuantStep.PREPARE,
@@ -111,23 +113,83 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
         quantize_(model, config)
 
         # Perform calibration with test data
-        model(X)
+        model(*x)
 
-        # Step 2. SmoothQuant
         config.step = SmoothQuantStep.CONVERT
         quantize_(model, config)
+        assert isinstance(
+            model.linear1.weight, WeightTensorWithLinearActivationScaleMetadata
+        )
+        assert isinstance(
+            model.linear2.weight, WeightTensorWithLinearActivationScaleMetadata
+        )
 
-        out_smoothquant = model(X)
+        out_smoothquant = model(*x)
         loss_smoothquant = torch.nn.functional.mse_loss(out_smoothquant, out_ref).item()
 
         assert loss_smoothquant < loss_base, (
             f"SmoothQuant loss ({loss_smoothquant:.6f}) should not be higher than basic loss ({loss_base:.6f})"
         )
 
+    @common_utils.parametrize("alpha", [0.5, 0.25])
+    @common_utils.parametrize("device", device_list)
+    @common_utils.parametrize("input_dtype", [torch.bfloat16])
+    def test_smoothquant_static_act_accuracy(self, alpha, device, input_dtype):
+        """Test if SmoothQuant with static quantization achieves lower loss than basic quantization."""
+        m = ToyLinearModel().eval().to(device).to(input_dtype)
+        x = m.example_inputs(batch_size=16, dtype=input_dtype, device=device)
+
+        # Output without quantization
+        out_ref = m(*x)
+
+        # Step 1. Reference with alpha=0
+        m_ref = deepcopy(m)
+        base_config = Int8StaticActivationInt8WeightConfig()
+        config = SmoothQuantConfig(
+            base_config=base_config,
+            step=SmoothQuantStep.PREPARE,
+            alpha=0.0,
+        )
+        with torch.no_grad():
+            quantize_(m_ref, config)
+            m_ref(*x)  # calibration
+        config.step = SmoothQuantStep.CONVERT
+        quantize_(m_ref, config)
+        out_base = m_ref(*x)
+        loss_base = torch.nn.functional.mse_loss(out_base, out_ref).item()
+
+        # Step 2. SmoothQuant quantization
+        base_config = Int8StaticActivationInt8WeightConfig()
+        config = SmoothQuantConfig(
+            base_config=base_config,
+            step=SmoothQuantStep.PREPARE,
+            alpha=alpha,
+        )
+        with torch.no_grad():
+            quantize_(m, config)
+            m(*x)  # calibration
+        config.step = SmoothQuantStep.CONVERT
+        quantize_(m, config)
+        out_sq = m(*x)
+        assert isinstance(
+            m.linear1.weight, WeightTensorWithLinearActivationScaleMetadata
+        )
+        assert isinstance(
+            m.linear2.weight, WeightTensorWithLinearActivationScaleMetadata
+        )
+        loss_smoothquant = torch.nn.functional.mse_loss(out_sq, out_ref).item()
+
+        assert loss_smoothquant < loss_base, (
+            f"SmoothQuant loss ({loss_smoothquant:.6f}) should not be higher than basic loss ({loss_base:.6f})"
+        )
+        # Make sure the result is reasonable
+        self.assertGreater(SQNR(out_ref, out_sq), 20.0)
+
     @common_utils.parametrize(
         "base_config",
         [
             Int8DynamicActivationInt8WeightConfig(),
+            Int8StaticActivationInt8WeightConfig(),
             # TODO: Check more quantization APIs
         ],
     )
@@ -167,6 +229,7 @@ def test_observer_insertion(self, base_config):
         "base_config",
         [
             Int8DynamicActivationInt8WeightConfig(),
+            Int8StaticActivationInt8WeightConfig(),
             # TODO: Check more quantization APIs
         ],
     )
diff --git a/torchao/prototype/smoothquant/README.md b/torchao/prototype/smoothquant/README.md
index 00e819c438..c5aeaea78a 100644
--- a/torchao/prototype/smoothquant/README.md
+++ b/torchao/prototype/smoothquant/README.md
@@ -50,6 +50,7 @@ for data in calibration_dataset:
 quant_config.step = SmoothQuantStep.CONVERT
 quantize_(model, quant_config)
 ```
+For static activation quantization, use `Int8StaticActivationInt8WeightConfig` instead of `Int8DynamicActivationInt8WeightConfig`. Static quantization generally gives better throughput than dynamic quantization, at the cost of some accuracy (higher perplexity).
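+
+For example, the prepare, calibrate, and convert flow stays the same as above; only the base config changes. The snippet below is an illustrative sketch and assumes the same `model`, `calibration_dataset`, and imports as the example above:
+
+```python
+from torchao.quantization.quant_api import Int8StaticActivationInt8WeightConfig
+
+quant_config = SmoothQuantConfig(
+    base_config=Int8StaticActivationInt8WeightConfig(),
+    step=SmoothQuantStep.PREPARE,
+    alpha=0.5,
+)
+quantize_(model, quant_config)
+
+# Calibrate so the observers can record activation statistics and derive a
+# static activation scale.
+for data in calibration_dataset:
+    model(data)
+
+quant_config.step = SmoothQuantStep.CONVERT
+quantize_(model, quant_config)
+```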
 
 ## Benchmarks
 
diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py
index 9f78c49fb8..b4a0f99e91 100644
--- a/torchao/prototype/smoothquant/api.py
+++ b/torchao/prototype/smoothquant/api.py
@@ -15,6 +15,7 @@
 )
 from torchao.quantization.quant_api import (
     _QUANTIZE_CONFIG_HANDLER,
+    Int8StaticActivationInt8WeightConfig,
     _linear_extra_repr,
 )
 from torchao.quantization.transform_module import (
@@ -96,7 +97,12 @@ def _smooth_quant_transform(
         raise ValueError(f"Unexpected step: {step}")
 
     # Compute smoothed weight parameters
-    smoothing_factor = observed_linear.obs.calculate_qparams()
+    act_quant_min, act_quant_max = None, None
+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        act_quant_min, act_quant_max = -127, 127
+    smoothing_factor, act_scale = observed_linear.obs.calculate_qparams(
+        act_quant_min, act_quant_max
+    )
     weight = observed_linear.weight * smoothing_factor
 
     # Create new linear layer
@@ -111,6 +117,8 @@ def _smooth_quant_transform(
         linear.bias = observed_linear.bias
 
     # Quantize weights
+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        base_config = Int8StaticActivationInt8WeightConfig(act_scale)
     base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
     dummy_mod = DummyModule(weight)
     quant_mod = base_config_handler(dummy_mod, base_config)
diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py
index 83f1e78275..0d93db69d5 100644
--- a/torchao/prototype/smoothquant/core.py
+++ b/torchao/prototype/smoothquant/core.py
@@ -41,7 +41,7 @@ def forward(self, input: torch.Tensor):
         self.inputs.append(input.to("cpu"))
         return input
 
-    def calculate_qparams(self):
+    def calculate_qparams(self, act_quant_min=None, act_quant_max=None):
         assert self.inputs and len(self.inputs) > 0, (
             "calibrate observer first by running model on exemplar data"
         )
@@ -54,15 +54,19 @@ def calculate_qparams(self):
         # Calculate per-channel max values
         x_abs_max = torch.max(torch.abs(acc), dim=0)[0]
         w_abs_max = torch.max(torch.abs(self.weight), dim=0)[0]
+        act_scale = None
+        if act_quant_min is not None and act_quant_max is not None:
+            x_abs_max_t = acc.abs().max()
+            act_scale = x_abs_max_t / (act_quant_max - act_quant_min) / 2
 
         # Calculate smoothing factor
         if self.alpha is None:
-            return torch.ones_like(x_abs_max)
+            return torch.ones_like(x_abs_max), act_scale
 
         eps = torch.finfo(torch.float32).eps
         return torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
             w_abs_max + eps, 1 - self.alpha
-        )
+        ), act_scale
 
 
 class SmoothQuantObservedLinear(torch.nn.Linear):
diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py
index 8602b57e20..b3b5bffa93 100644
--- a/torchao/prototype/smoothquant/example.py
+++ b/torchao/prototype/smoothquant/example.py
@@ -16,7 +16,10 @@
 )
 from torchao.prototype.smoothquant.core import SmoothQuantStep
 from torchao.quantization import quantize_
-from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig
+from torchao.quantization.quant_api import (
+    Int8DynamicActivationInt8WeightConfig,
+    Int8StaticActivationInt8WeightConfig,
+)
 
 
 # TODO: Build benchmark within vLLM ecosystem with more quantization APIs
@@ -82,13 +85,15 @@ def quantize_and_eval(
     device: str,
     model_save_path: str,
     model_save_hf_hub_path: str,
+    static_quant_act: bool,
+    compile: bool,
 ):
     print(f"Loading model on {device}...")
     torch.manual_seed(34)
     t0 = time.time()
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     model = (
-        AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)
+        AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
         .eval()
         .to(device)
     )
@@ -96,9 +101,14 @@ def quantize_and_eval(
 
     # Step 1: Prepare - insert observers
     print("running SmoothQuant prepare and calibrate")
+    base_config = (
+        Int8StaticActivationInt8WeightConfig()
+        if static_quant_act
+        else Int8DynamicActivationInt8WeightConfig()
+    )
     t0 = time.time()
     quant_config = SmoothQuantConfig(
-        base_config=Int8DynamicActivationInt8WeightConfig(),
+        base_config=base_config,
         step=SmoothQuantStep.PREPARE,
         alpha=alpha,
     )
@@ -134,6 +144,9 @@ def quantize_and_eval(
         model.push_to_hub(model_save_hf_hub_path, safe_serialization=False)
         tokenizer.push_to_hub(model_save_hf_hub_path)
 
+    if compile:
+        model.forward = torch.compile(model.forward, dynamic=True)
+
     print("Benchmarking SmoothQuant model...")
     return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device)
 
@@ -147,6 +160,8 @@ def compare_models(
     device: str,
     model_save_path: str,
     model_save_hf_hub_path: str,
+    static_quant_act: bool,
+    compile: bool,
 ):
     """Compare perplexity and speed for behchmarking SmoothQuant"""
 
@@ -155,10 +170,12 @@ def compare_models(
     torch.manual_seed(34)
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     model = (
-        AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)
+        AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
         .eval()
         .to(device)
     )
+    if compile:
+        model.forward = torch.compile(model.forward, dynamic=True)
     base_results = benchmark(
         model, tokenizer, max_seq_length, tasks=tasks, device=device
     )
@@ -167,11 +184,13 @@ def compare_models(
     print("Benchmarking W8A8-dynamic without SmoothQuant...")
     torch.manual_seed(34)
     w8a8_model = (
-        AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)
+        AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
         .eval()
         .to(device)
     )
     quantize_(w8a8_model, Int8DynamicActivationInt8WeightConfig())
+    if compile:
+        w8a8_model.forward = torch.compile(w8a8_model.forward, dynamic=True)
     w8a8_results = benchmark(
         w8a8_model, tokenizer, max_seq_length, tasks=tasks, device=device
     )
@@ -187,6 +206,8 @@ def compare_models(
         device,
         model_save_path,
         model_save_hf_hub_path,
+        static_quant_act,
+        compile,
     )
 
     # Calculate changes and display results
@@ -289,6 +310,16 @@ def create_parser() -> argparse.ArgumentParser:
         default=None,
         help="Huggingface hub path to store the quantized model and tokenizer.",
     )
+    parser.add_argument(
+        "--static-quant-act",
+        action="store_true",
+        help="Use static quantization of activation instead of dynamic quantization.",
+    )
+    parser.add_argument(
+        "--compile",
+        action="store_true",
+        help="Use torch.compile to compile the model for potentially better performance.",
+    )
     return parser
 
 
@@ -306,4 +337,6 @@ def create_parser() -> argparse.ArgumentParser:
         args.device,
         args.model_save_path,
         args.model_save_hf_hub_path,
+        args.static_quant_act,
+        args.compile,
     )
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 15caddcadc..a7dbaf4cd9 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -46,6 +46,7 @@
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
+    to_affine_quantized_intx_static,
     to_marlinqqq_quantized_intx,
 )
 from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -1611,6 +1612,108 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     return Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout())
 
 
+def _activation_static_sym_quant_func_int8(
+    x: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    assert zero_point is None, "Zero point must be None"
+    quant_min = -127
+    quant_max = 127
+    target_dtype = torch.int8
+    zero_point_domain = ZeroPointDomain.NONE
+
+    return to_affine_quantized_intx_static(
+        x,
+        scale=scale,
+        zero_point=zero_point,
+        block_size=get_block_size(x.shape, PerTensor()),
+        target_dtype=target_dtype,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        zero_point_domain=zero_point_domain,
+    )
+
+
+@dataclass
+class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
+    """
+    Configuration for applying int8 static activation and int8 per-channel weight
+    quantization to linear layers. Activation is always quantized with symmetric per-tensor quantization.
+
+    Args:
+        act_scale: Optional[torch.Tensor] = None - Precomputed per-tensor activation scale, typically obtained from calibration.
+        layout: Optional[Layout] = PlainLayout() - Tensor layout for the quantized weights. Controls how the
+            quantized data is stored and accessed. Only PlainLayout is supported now.
+        set_inductor_config: bool = False - If True, adjusts `torchinductor` settings to recommended values
+            for better performance with this quantization scheme.
+    """
+
+    act_scale: Optional[torch.Tensor] = None
+    layout: Optional[Layout] = PlainLayout()
+    set_inductor_config: bool = False
+
+    def __post_init__(self):
+        assert isinstance(self.layout, PlainLayout), (
+            f"Only support PlainLayout for layout, got {self.layout}"
+        )
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8StaticActivationInt8WeightConfig"
+        )
+
+
+def _int8_static_activation_int8_weight_quantize_tensor(weight, config):
+    act_scale = config.act_scale
+    layout = config.layout
+
+    # weight settings
+    mapping_type = MappingType.SYMMETRIC
+    weight_zero_point_domain = ZeroPointDomain.NONE
+
+    target_dtype = torch.int8
+    eps = torch.finfo(torch.float32).eps
+    zero_point_dtype = torch.int64
+
+    block_size = get_block_size(weight.shape, PerRow())
+    new_weight = to_affine_quantized_intx(
+        weight,
+        mapping_type,
+        block_size,
+        target_dtype,
+        eps=eps,
+        zero_point_dtype=zero_point_dtype,
+        _layout=layout,
+        zero_point_domain=weight_zero_point_domain,
+    )
+    new_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
+        new_weight,
+        _activation_static_sym_quant_func_int8,
+        scale=act_scale,
+        zero_point=None,
+    )
+    return new_weight
+
+
+@register_quantize_module_handler(Int8StaticActivationInt8WeightConfig)
+def _int8_static_activation_int8_weight_transform(
+    module: torch.nn.Module, config: Int8StaticActivationInt8WeightConfig
+) -> torch.nn.Module:
+    assert config.act_scale is not None, (
+        "act_scale must be provided for static activation quantization"
+    )
+    assert config.act_scale.numel() == 1, "Only support per-tensor quantization"
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    assert hasattr(module, "weight"), (
+        "applying int8 static activation int8 weight quant requires module to have weight attribute "
+        + f"but {module} does not have one"
+    )
+    new_weight = _int8_static_activation_int8_weight_quantize_tensor(
+        module.weight, config
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 @dataclass
 class Float8WeightOnlyConfig(AOBaseConfig):
     """