From 8454ccc34be8c513a7473e4f652dce14c90bc574 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Fri, 11 Apr 2025 05:04:29 -0700 Subject: [PATCH 01/18] Enabling MOE Quantization using linear decomposition Summary: This PR is a first step toward optimizing moe inference using torchAO. The goal for this step is to enable existing quantization kernels and workflows to work for moe quantization by decomposing the group gemm into a sequence of unbalanced linear ops that can use the existing quantized kernels. To enable this we had to add support for quantizing these 3D tensors as well as for slicing and indexing them. Two methods of achieving this were implemented. For int8wo, int8dq, int4wo, fp8wo and fp8dq, the underlying quantized tensor subclass was adapted to support 3D tensors, indexing and slicing, along with an updated transformation function that can handle the ConditionalFeedForwardAOQuantizable modules if the filter function in quantize_ is used to target the aforementioned module. For some complex kernels which use packed data and couldn't easily be made to work in 3D, we also added FakeExtraDimTensor, which can transform any quantized tensor subclass into one supporting the necessary slice and index operations for moe quantization. This option is enabled by using MoEQuantConfig. This can be applied to huggingface llama4, for instance, as shown in the llama4_quant.py example. Since the hf moe module is implemented in a way that's not conducive to quantization, it first requires a module swap to the MOEFeedForwardAOQuantizable module. TODO: final benchmark numbers from run.sh, consolidate the 3x implementations of MOEFeedForwardAOQuantizable and ConditionalFeedForwardAOQuantizable, verify hqq Test Plan: python test/quantization/test_moe_quant.py python test/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py -k "test_moe_quant_intx" sh torchao/_models/mixtral-moe/run.sh Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_moe_quant.py | 288 +++++ torchao/_models/mixtral-moe/generate.py | 483 ++++++++++++++++++ torchao/_models/mixtral-moe/model.py | 456 +++++++++++++++++ torchao/_models/mixtral-moe/run.sh | 39 ++ .../scripts/convert_hf_checkpoint.py | 115 +++++ .../_models/mixtral-moe/scripts/download.py | 48 ++ torchao/dtypes/affine_quantized_tensor_ops.py | 60 ++- torchao/dtypes/floatx/float8_layout.py | 7 + torchao/dtypes/uintx/plain_layout.py | 8 + .../dtypes/uintx/tensor_core_tiled_layout.py | 146 ++++-- ...est_int8_dynamic_activation_intx_weight.py | 34 ++ .../linear_activation_quantized_tensor.py | 30 ++ .../prototype/moe_quant/README.md | 45 ++ .../prototype/moe_quant/__init__.py | 0 .../prototype/moe_quant/llama4_quant.py | 89 ++++ .../moe_quant/quantizable_moe_modules.py | 186 +++++++ .../quantization/prototype/moe_quant/utils.py | 261 ++++++++++ torchao/quantization/quant_api.py | 236 ++++++--- torchao/quantization/transform_module.py | 1 + torchao/quantization/utils.py | 13 +- torchao/utils.py | 4 +- 21 files changed, 2419 insertions(+), 130 deletions(-) create mode 100644 test/quantization/test_moe_quant.py create mode 100644 torchao/_models/mixtral-moe/generate.py create mode 100644 torchao/_models/mixtral-moe/model.py create mode 100644 torchao/_models/mixtral-moe/run.sh create mode 100644 torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py create mode 100644 torchao/_models/mixtral-moe/scripts/download.py create mode 100644 torchao/quantization/prototype/moe_quant/README.md create mode 100644 torchao/quantization/prototype/moe_quant/__init__.py create mode 
100644 torchao/quantization/prototype/moe_quant/llama4_quant.py create mode 100644 torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py create mode 100644 torchao/quantization/prototype/moe_quant/utils.py diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py new file mode 100644 index 0000000000..cbeb3ad308 --- /dev/null +++ b/test/quantization/test_moe_quant.py @@ -0,0 +1,288 @@ +import unittest + +import torch +from parameterized import parameterized + +from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl +from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl +from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import ( + MOEFeedForwardAOQuantizable, +) +from torchao.quantization.prototype.moe_quant.utils import ( + FakeExtraDimTensor, + MoEQuantConfig, + cond_ffn_filter, +) +from torchao.quantization.quant_api import ( + AffineQuantizedTensor, + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + LinearActivationQuantizedTensor, + quantize_, +) +from torchao.quantization.utils import compute_error +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_90 + + +class TestMoEQuantCompile(unittest.TestCase): + DEFAULT_PARAMS = (512, 256, 8, 2) # hidden_dim, expert_dim, num_experts, top_k + + @torch.no_grad() + def _test_impl_moe_quant( + self, + config, + num_tokens=1, + model_params=None, + base_class=AffineQuantizedTensor, + tensor_impl_class=None, + dtype=torch.bfloat16, + device="cuda", + fullgraph=False, + ): + """ + Tests moe quant for techniques using fake extra dim + """ + if model_params is None: + model_params = self.DEFAULT_PARAMS + + input_shape = (num_tokens, model_params[0]) + model = ( + MOEFeedForwardAOQuantizable(*model_params, empty_init=False) + .to(dtype) + .to(device) + ) + input = torch.randn(input_shape, dtype=torch.bfloat16, device=device) + + out = model(input) + + quantize_(model, config, cond_ffn_filter) + + if isinstance(config, MoEQuantConfig): + self.assertIsInstance(model.experts.w1, FakeExtraDimTensor) + if base_class is not None: + self.assertIsInstance(model.experts.w1.head_tensor, base_class) + if tensor_impl_class is not None: + self.assertIsInstance( + model.experts.w1.head_tensor.tensor_impl, tensor_impl_class + ) + else: + if base_class is not None: + self.assertIsInstance(model.experts.w1, base_class) + if tensor_impl_class is not None: + self.assertIsInstance(model.experts.w1.tensor_impl, tensor_impl_class) + + out_q = model(input) + + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + model_c = torch.compile(model, mode="reduce-overhead", fullgraph=fullgraph) + + model_c(input) + model_c(input) + out_qc = model_c(input).clone() + + for i in range(10): + input = torch.randn(input_shape, dtype=torch.bfloat16, device=device) + model_c(input) + + self.assertGreaterEqual(compute_error(out_q, out), 10) + self.assertGreaterEqual(compute_error(out_qc, out), 10) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") + @parameterized.expand( + [ + ("single_token", 1, False), + ("multiple_tokens", 8, False), + ] + ) + def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): + config = 
MoEQuantConfig(Int4WeightOnlyConfig()) + tensor_impl_class = TensorCoreTiledAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") + @parameterized.expand( + [ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ] + ) + def test_int4wo_base(self, name, num_tokens, fullgraph): + config = Int4WeightOnlyConfig() + tensor_impl_class = TensorCoreTiledAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") + @parameterized.expand( + [ + ("single_token", 1, False), + ("multiple_tokens", 8, False), + ] + ) + def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): + config = MoEQuantConfig(Int8WeightOnlyConfig()) + tensor_impl_class = PlainAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") + @parameterized.expand( + [ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ] + ) + def test_int8wo_base(self, name, num_tokens, fullgraph): + config = Int8WeightOnlyConfig() + tensor_impl_class = PlainAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") + @parameterized.expand( + [ + ("multiple_tokens", 32, False), + ] + ) + def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) + base_class = LinearActivationQuantizedTensor + + self._test_impl_moe_quant( + model_params=(512, 256, 2, 2), + config=config, + num_tokens=num_tokens, + base_class=base_class, + fullgraph=fullgraph, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") + @parameterized.expand( + [ + ("multiple_tokens", 32, False), + ] + ) + def test_int8dq_base(self, name, num_tokens, fullgraph): + config = Int8DynamicActivationInt8WeightConfig() + base_class = LinearActivationQuantizedTensor + + self._test_impl_moe_quant( + model_params=(512, 256, 2, 2), + config=config, + num_tokens=num_tokens, + base_class=base_class, + fullgraph=fullgraph, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") + @parameterized.expand( + [ + ("single_token", 1, False), + ("multiple_tokens", 8, False), + ] + ) + def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph): + config = MoEQuantConfig(Float8WeightOnlyConfig()) + tensor_impl_class = Float8AQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + 
fullgraph=fullgraph, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") + @parameterized.expand( + [ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ] + ) + def test_fp8wo_base(self, name, num_tokens, fullgraph): + config = Float8WeightOnlyConfig() + tensor_impl_class = Float8AQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") + @parameterized.expand( + [ + ("single_token", 1, False), + ("multiple_tokens", 8, False), + ] + ) + def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph): + config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig()) + base_class = LinearActivationQuantizedTensor + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + base_class=base_class, + fullgraph=fullgraph, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") + @parameterized.expand( + [ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ] + ) + def test_fp8dq_base(self, name, num_tokens, fullgraph): + config = Float8DynamicActivationFloat8WeightConfig() + base_class = LinearActivationQuantizedTensor + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + base_class=base_class, + fullgraph=fullgraph, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py new file mode 100644 index 0000000000..c9cf8e6f37 --- /dev/null +++ b/torchao/_models/mixtral-moe/generate.py @@ -0,0 +1,483 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+import itertools +import sys +import time +from pathlib import Path +from typing import Optional, Tuple + +import torch +import torch._dynamo.config +import torch._inductor.config + +from torchao.utils import get_model_size_in_bytes + +torch.manual_seed(0) + + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif "cpu" in device: + pass + else: + print(f"device={device} is not yet supported") + + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future +torch._dynamo.config.capture_scalar_outputs = True + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from model import Transformer +from sentencepiece import SentencePieceProcessor + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[:, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def prefill( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> torch.Tensor: + # input_pos: [B, S] + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs)[0] + + +def decode_one_token( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs) + + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + callback=lambda _: _, + **sampling_kwargs, +): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + next_token, next_prob = next_token.clone(), next_prob.clone() + + input_pos += 1 + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + new_probs.append(next_prob.clone()) + cur_token = next_token + + return new_tokens, new_probs + + +def model_forward(model, x, input_pos): + return model(x, input_pos) + + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + batch_size: int, + *, + interactive: bool, + callback=lambda x: x, + **sampling_kwargs, +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
+ """ + device, _ = prompt.device, prompt.dtype + + T = prompt.size(-1) + max_seq_length = ( + min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + ) + new_tokens = max_seq_length - T + + # duplicate prompt for batch_size + prompt = prompt.repeat(batch_size, 1) + + # create an empty tensor of the expected final shape and fill in the current tokens + seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device) + seq[:, :T] = prompt + + with torch.device(device): + model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length) + + input_pos = torch.arange(0, T, device=device) + next_token = prefill( + model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs + ) + seq[:, T] = next_token.squeeze() + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + generated_tokens, _ = decode_n_tokens( + model, + next_token.view(batch_size, -1), + input_pos, + new_tokens - 1, + callback=callback, + **sampling_kwargs, + ) + seq = torch.cat((seq[:, : T + 1], *generated_tokens), dim=-1) + + return seq + + +def encode_tokens(tokenizer, string, bos=True, device="cuda"): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + + +def _load_model(checkpoint_path, device, precision): + with torch.device("meta"): + model = Transformer.from_name(checkpoint_path.parent.name) + + try: + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + except: + model = Transformer.from_name(checkpoint_path.parent.name) + + model = model.to(device=device, dtype=precision) + return model.eval() + + +B_INST, E_INST = "[INST]", "[/INST]" + + +def main( + prompt: str = "Hello, my name is", + interactive: bool = False, + num_samples: int = 5, + max_new_tokens: int = 100, + batch_size: int = 1, + top_k: int = 200, + temperature: float = 0.8, + checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"), + compile: bool = True, + compile_prefill: bool = False, + moe_quant: Optional[str] = None, + profile: Optional[Path] = None, + memory_profile: Optional[Path] = None, + device="cuda", +) -> None: + """Generates text samples based on a pre-trained Transformer model and tokenizer.""" + assert checkpoint_path.is_file(), checkpoint_path + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") + print(f"Using device={device}") + precision = torch.bfloat16 + is_chat = "chat" in str(checkpoint_path) + + if device == "cuda" and memory_profile is not None: + torch.cuda.memory._record_memory_history( + True, trace_alloc_max_entries=500000, trace_alloc_record_context=True + ) + + print("Loading model ...") + t0 = time.time() + model = _load_model(checkpoint_path, device, precision) + + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + prompt_length = encoded.size(0) + + torch.manual_seed(1234) + model_size = sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain(model.parameters(), model.buffers()) + ] + ) + + from torchao.quantization.prototype.moe_quant.utils import ( + MoEQuantConfig, + cond_ffn_filter, + ) + from torchao.quantization.quant_api import 
( + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int8DynamicActivationIntxWeightConfig, + Int8WeightOnlyConfig, + PackedLinearInt8DynamicActivationIntxWeightLayout, + PerRow, + quantize_, + ) + + if moe_quant: + torch._dynamo.config.capture_dynamic_output_shape_ops = True + config = None + if "int8wo-base" in moe_quant: + config = Int8WeightOnlyConfig() + + elif "int8wo" in moe_quant: + config = MoEQuantConfig(Int8WeightOnlyConfig()) + + elif "int8dq-base" in moe_quant: + config = Int8DynamicActivationInt8WeightConfig() + + elif "int8dq" in moe_quant: + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) + + elif "int4wo-base" in moe_quant: + config = Int4WeightOnlyConfig() + + elif "int4wo" in moe_quant: + config = MoEQuantConfig(Int4WeightOnlyConfig()) + + elif "fp8wo-base" in moe_quant: + config = Float8WeightOnlyConfig() + + elif "fp8wo" in moe_quant: + config = MoEQuantConfig(Float8WeightOnlyConfig()) + + elif "fp8dq-base" in moe_quant: + config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + + elif "fp8dq" in moe_quant: + config = MoEQuantConfig( + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + ) + + elif "intxdq" in moe_quant: + config = Int8DynamicActivationIntxWeightConfig( + layout=PackedLinearInt8DynamicActivationIntxWeightLayout() + ) + else: + assert config is not None, ( + f"expected moe_quant to match one of the options but got {moe_quant}" + ) + + if config is not None: + quantize_(model, config, filter_fn=cond_ffn_filter) + torch.cuda.reset_peak_memory_stats() + + if compile: + # moe quant + compile causes repeated warnings + import warnings + + warnings.simplefilter("ignore", lineno=84) + warnings.simplefilter("ignore", lineno=105) + + torch._inductor.config.assert_indirect_indexing = False + + global decode_one_token, prefill + + if batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant): + decode_one_token = torch.compile( + decode_one_token, mode="reduce-overhead", fullgraph=True + ) + else: + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead") + + if compile_prefill: + prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + + aggregate_metrics = { + "tokens_per_sec": [], + } + start = -1 if compile else 0 + + for i in range(start, num_samples): + device_sync(device=device) # MKG + if i >= 0 and interactive: + prompt = input("What is your prompt? 
") + if is_chat: + prompt = f"{B_INST} {prompt.strip()} {E_INST}" + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + + if interactive and i >= 0: + buffer = [] + period_id = tokenizer.encode(".")[0] + done_generating = False + + def callback(x): + nonlocal done_generating + if done_generating: + return + buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) + if x.item() == tokenizer.eos_id(): + done_generating = True + if len(buffer) == 4 or done_generating: + print("".join(buffer), end="", flush=True) + buffer.clear() + # print(, end='', flush=True) + else: + callback = lambda x: x + t0 = time.perf_counter() + import contextlib + + if i != num_samples - 1 or not profile: + prof = contextlib.nullcontext() + else: + torch.profiler._utils._init_for_cuda_graphs() + prof = torch.profiler.profile() + with prof: + y = generate( + model, + encoded, + max_new_tokens, + batch_size, + interactive=interactive, + callback=callback, + temperature=temperature, + top_k=top_k, + ) + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue + if hasattr(prof, "export_chrome_trace"): + prof.export_chrome_trace(f"{profile}.json") + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + + if not interactive: + pass + print(tokenizer.decode(y[0].tolist())) + else: + print() + tokens_generated = y.size(-1) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + print( + f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec" + ) + print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + + if i == 0 and device == "cuda" and memory_profile is not None: + snapshot = torch.cuda.memory._snapshot() + with open(f"{memory_profile}.pickle", "wb") as f: + from pickle import dump + + dump(snapshot, f) + print( + f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use", + "python pytorch/torch/cuda/_memory_viz.py trace_plot -o .html", + ) + break + + tokpersec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item() + print(f"Average tokens/sec: {tokpersec:.2f}") + if batch_size > 1: + print(f"Average tokens/sec including batches {batch_size * tokpersec:.2f}") + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") + print(f"model size: {get_model_size_in_bytes(model) / 1e9:.02f}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Your CLI description.") + + parser.add_argument( + "--prompt", type=str, default="Hello, my name is", help="Input prompt." + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Whether to launch in interactive mode", + ) + parser.add_argument("--num_samples", type=int, default=5, help="Number of samples.") + parser.add_argument( + "--max_new_tokens", type=int, default=200, help="Maximum number of new tokens." + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size to benchmark with" + ) + parser.add_argument("--top_k", type=int, default=200, help="Top-k for sampling.") + parser.add_argument( + "--temperature", type=float, default=0.8, help="Temperature for sampling." + ) + parser.add_argument( + "--checkpoint_path", + type=Path, + default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), + help="Model checkpoint path.", + ) + parser.add_argument( + "--compile", action="store_true", help="Whether to compile the model." 
+ ) + parser.add_argument( + "--compile_prefill", + action="store_true", + help="Whether to compile the prefill (improves prefill perf, but higher compile times)", + ) + # parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8') + parser.add_argument( + "--moe_quant", + type=str, + help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq", + ) + parser.add_argument("--profile", type=Path, default=None, help="Profile path.") + parser.add_argument( + "--memory_profile", type=Path, default=None, help="filename for memory profile." + ) + parser.add_argument("--device", type=str, default="cuda", help="device to use") + + args = parser.parse_args() + print(args) + main( + args.prompt, + args.interactive, + args.num_samples, + args.max_new_tokens, + args.batch_size, + args.top_k, + args.temperature, + args.checkpoint_path, + args.compile, + args.compile_prefill, + args.moe_quant, + args.profile, + args.memory_profile, + args.device, + ) diff --git a/torchao/_models/mixtral-moe/model.py b/torchao/_models/mixtral-moe/model.py new file mode 100644 index 0000000000..48dda5bef9 --- /dev/null +++ b/torchao/_models/mixtral-moe/model.py @@ -0,0 +1,456 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + +from torchao.quantization.prototype.moe_quant.utils import FakeExtraDimTensor + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + num_experts: int = 8 + num_activated_experts: int = 2 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [ + config + for config in transformer_configs + if config in str(name).upper() or config in str(name) + ] + assert len(config) == 1, name + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "Mixtral-8x7B-Instruct-v0.1": dict( + block_size=32768, + n_layer=32, + n_head=32, + n_local_heads=8, + dim=4096, + intermediate_size=14336, + rope_base=1000000.0, + num_experts=8, + num_activated_experts=2, + ), +} + + +class KVCache(nn.Module): + def __init__( + self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16 + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = 
self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.layers = nn.ModuleList( + TransformerBlock(config) for _ in range(config.n_layer) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length): + if ( + self.max_seq_length >= max_seq_length + and self.max_batch_size >= max_batch_size + ): + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, max_seq_length, self.config.n_local_heads, head_dim + ) + + self.freqs_cis = precompute_freqs_cis( + self.config.block_size, + self.config.dim // self.config.n_head, + self.config.rope_base, + ) + self.causal_mask = torch.tril( + torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) + ) + + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + x = self.tok_embeddings(idx) + + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask) + x = self.norm(x) + logits = self.output(x) + return logits + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.block_sparse_moe = MOEFeedForwardAOQuantizable(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward( + self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor + ) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.block_sparse_moe(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = 
self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +# class ConditionalFeedForward(nn.Module): +# def __init__(self, config): +# super().__init__() +# self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) +# self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) +# self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) + +# def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: +# w1_weights = self.w1[expert_indices] # [T, A, D, D] +# w3_weights = self.w3[expert_indices] # [T, A, D, D] +# w2_weights = self.w2[expert_indices] # [T, A, D, D] +# x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights)) +# x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) +# expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) +# return expert_outs + + +# class MOEFeedForward(nn.Module): +# def __init__(self, config) -> None: +# super().__init__() +# self.gate = nn.Linear(config.dim, config.num_experts, bias=False) +# self.cond_ffn = ConditionalFeedForward(config) +# self.dim = config.dim +# self.num_activated_experts = config.num_activated_experts +# def forward(self, x: Tensor) -> Tensor: +# x = x.view(-1, self.dim) +# # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts +# # x: [T, D] +# scores = self.gate(x) # [T, E] +# expert_weights = F.softmax(scores, dim=-1) +# expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] +# expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] +# expert_outs = self.cond_ffn(x, expert_indices) +# return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * 
freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) + + +# T tokens +# E experts +# D dim +# I intermediate dim +# A activated experts +# T'(e) tokens for expert e + + +class MOEFeedForwardAOQuantizable(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) + self.cond_ffn = ConditionalFeedForwardAOQuantizable(config) + self.dim = config.dim + self.num_activated_experts = config.num_activated_experts + + def forward(self, x: Tensor) -> Tensor: + batch_size = x.shape[0] + x = x.view(-1, self.dim) # x: [T, D] + scores = self.gate(x) # [T, E] + expert_weights = F.softmax(scores, dim=-1) + expert_weights, expert_indices = torch.topk( + expert_weights, self.num_activated_experts, dim=-1 + ) # [T, A], [T, A] + expert_weights /= expert_weights.sum(dim=-1, keepdim=True).to(x.dtype) # [T, A] + out = self.cond_ffn( + x, expert_indices, expert_weights, self.num_activated_experts + ) + return out.reshape(batch_size, -1, self.dim) + + +class ConditionalFeedForwardAOQuantizable(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.w1 = nn.Parameter( + torch.empty(config.num_experts, config.intermediate_size, config.dim) + ) # E, I, D + self.w2 = nn.Parameter( + torch.empty(config.num_experts, config.dim, config.intermediate_size) + ) # E, D, I + self.w3 = nn.Parameter( + torch.empty(config.num_experts, config.intermediate_size, config.dim) + ) # E, I, D + self.num_experts = config.num_experts + + def forward( + self, + x: Tensor, # T, D + expert_indices: Tensor, # T, A + expert_weights: Tensor, # T, A + num_activated_experts: int, + ) -> Tensor: + num_tokens, dim = x.shape + num_token_activations = num_tokens * num_activated_experts + if x.shape[0] == 1 and not isinstance( + self.w1, FakeExtraDimTensor + ): # only 1 token (can be done without graph breaks when compiled) + outs = [] + expert_indices = expert_indices.view(num_activated_experts) + # collect used experts + w1 = self.w1[expert_indices] + w2 = self.w2[expert_indices] + w3 = self.w3[expert_indices] + + # run token through each expert + for index in range(num_activated_experts): + y1 = F.silu(F.linear(x, w1[index])) + y3 = F.linear(x, w3[index]) + y2 = w2[index] + cur_out = F.linear(y1 * y3, y2) + outs.append(cur_out) + + # combine outputs + final_out = ( + (torch.cat(outs, dim=0) * expert_weights.view(-1, 1)) + .sum(dim=0) + .unsqueeze(-1) + ) + return final_out + else: + expert_list = [x for x in range(self.num_experts)] + + # shuffle tokens into groups for each expert + ordered_token_activations = expert_indices.view(-1).argsort( + stable=True + ) # [A] + ordered_token_indices = ( + ordered_token_activations.div(num_activated_experts) + .floor() + .to(torch.int64) + ) # [T] + + num_tokens_per_expert = torch.histc( + expert_indices, bins=self.num_experts + 1, min=-1, max=self.num_experts + ) # [E+1] (added leading 0 so can be used for indexing) + cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to( + torch.int64 + ) # [E+1] + + @torch._dynamo.disable() + def group_tokens_by_expert( + ordered_token_indices, cum_tokens_per_expert, expert_list + ): + token_indices_per_expert = [ + ordered_token_indices[ + cum_tokens_per_expert[expert] : cum_tokens_per_expert[ + expert + 1 + ] + ] + for expert in expert_list + ] # [T'(e1)], [T'(e2)] ... 
+ return token_indices_per_expert + + token_indices_per_expert = group_tokens_by_expert( + ordered_token_indices, cum_tokens_per_expert, expert_list + ) + tokens_grouped_by_expert = [ + x[indices] for indices in token_indices_per_expert + ] + + # calculate outputs for each expert + outs = [] + for cur_x, expert in zip(tokens_grouped_by_expert, expert_list): + w1 = self.w1[expert] # I, D + w2 = self.w2[expert] # D, I + w3 = self.w3[expert] # I, D + + cur_out = F.linear( + F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2 + ) # [T'(e), D] + outs.append(cur_out) + + # weigh outputs + ordered_outs = torch.cat(outs, dim=0) # [T*A, D] + ordered_token_activation_weights = expert_weights.view(-1, 1)[ + ordered_token_activations + ].view(-1, 1) # [T*A, 1] + weighted_ordered_outs = ( + ordered_outs * ordered_token_activation_weights + ) # [T*A, D] + + # sum weighted token-activation outputs together for each token + final_out = torch.zeros_like(x) # [T, D] + final_out = final_out.scatter_add( + dim=0, + index=ordered_token_indices.unsqueeze(-1) + .expand(num_token_activations, dim) + .to(torch.int64), + src=weighted_ordered_outs, + ) + return final_out diff --git a/torchao/_models/mixtral-moe/run.sh b/torchao/_models/mixtral-moe/run.sh new file mode 100644 index 0000000000..692acfa1c4 --- /dev/null +++ b/torchao/_models/mixtral-moe/run.sh @@ -0,0 +1,39 @@ +export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1 +export CHECKPOINT_PATH=~/checkpoints/ + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8wo --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8wo --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8wo-base --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8wo-base --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int4wo --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int4wo --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int4wo-base --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int4wo-base --compile + +# EXPERT CHOICE +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq --compile +# # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq --compile +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq-base --compile +# # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq-base --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8wo --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8wo --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant 
fp8wo-base --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8wo-base --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8dq --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8dq --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8dq-base --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8dq-base --compile + +# ARM +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant intxdq --device cpu +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant intxdq --compile --device cpu diff --git a/torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py b/torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py new file mode 100644 index 0000000000..6a39578e32 --- /dev/null +++ b/torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import glob +import re +import sys +from pathlib import Path +from typing import Optional + +import torch + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from model import ModelArgs + + +@torch.inference_mode() +def convert_hf_checkpoint( + *, + checkpoint_dir: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1"), + model_name: Optional[str] = None, +) -> None: + if model_name is None: + model_name = checkpoint_dir.name + + config = ModelArgs.from_name(model_name) + print(f"Model config {config.__dict__}") + + weight_map = { + "tok_embeddings.weight": "tok_embeddings.weight", + "layers.{}.attention.wq.weight": "layers.{}.attention.wq.weight", + "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight", + "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight", + "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight", + "layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1", + "layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2", + "layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3", + "layers.{}.block_sparse_moe.gate.weight": "layers.{}.block_sparse_moe.gate.weight", + "layers.{}.attention_norm.weight": "layers.{}.attention_norm.weight", + "layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight", + "norm.weight": "norm.weight", + "output.weight": "output.weight", + } + + pt_files = glob.glob(str(checkpoint_dir / "*.pt")) + + merged_result = {} + for file in sorted(pt_files): + state_dict = torch.load( + str(file), map_location="cpu", mmap=True, weights_only=True + ) + merged_result.update(state_dict) + final_result = {} + for key, value in merged_result.items(): + if "layers" in key: + abstract_key = re.sub(r".(\d+).", ".{}.", key) + layer_num = re.search(r"\d+", key).group(0) + new_key = weight_map[abstract_key] + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = weight_map[key] + + final_result[new_key] = value + + for key in tuple(final_result.keys()): + if "wq" in key: + q = final_result[key] + k = 
final_result[key.replace("wq", "wk")] + v = final_result[key.replace("wq", "wv")] + final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) + del final_result[key] + del final_result[key.replace("wq", "wk")] + del final_result[key.replace("wq", "wv")] + elif "w1" in key or "w3" in key: + final_result[key] = ( + final_result[key] + .reshape(config.num_experts, config.intermediate_size, config.dim) + .contiguous() + ) + elif "w2" in key: + final_result[key] = ( + final_result[key] + .reshape(config.num_experts, config.intermediate_size, config.dim) + .permute(0, 2, 1) + .contiguous() + ) + elif "gate" in key: + final_result[key] = final_result[key].contiguous() + + print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") + torch.save(final_result, checkpoint_dir / "model.pth") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.") + parser.add_argument( + "--checkpoint_dir", + type=Path, + default=Path("checkpoints/mistralai/Mixtral-8x7B-v0.1"), + ) + parser.add_argument("--model_name", type=str, default=None) + + args = parser.parse_args() + convert_hf_checkpoint( + checkpoint_dir=args.checkpoint_dir, + model_name=args.model_name, + ) diff --git a/torchao/_models/mixtral-moe/scripts/download.py b/torchao/_models/mixtral-moe/scripts/download.py new file mode 100644 index 0000000000..ec9077d3dc --- /dev/null +++ b/torchao/_models/mixtral-moe/scripts/download.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import os +from typing import Optional + +from requests.exceptions import HTTPError + + +def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: + from huggingface_hub import snapshot_download + + os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) + try: + snapshot_download( + repo_id, + local_dir=f"checkpoints/{repo_id}", + local_dir_use_symlinks=False, + token=hf_token, + ignore_patterns="*.safetensors", + ) + except HTTPError as e: + if e.response.status_code == 401: + print( + "You need to pass a valid `--hf_token=...` to download private checkpoints." + ) + else: + raise e + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.") + parser.add_argument( + "--repo_id", + type=str, + default="checkpoints/mistralai/Mixtral-8x7B-Instruct-v0.1", + help="Repository ID to download from.", + ) + parser.add_argument( + "--hf_token", type=str, default=None, help="HuggingFace API token." 
+ ) + + args = parser.parse_args() + hf_download(args.repo_id, args.hf_token) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index bf1bdacb68..0283e67760 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -476,12 +476,15 @@ def _(func, types, args, kwargs): shape = list(self.shape) shape[dim] = end - start block_size = self.block_size - assert len(block_size) == 2, ( - f"Slice only works for 2d block_size right now, got: {block_size}" - ) + assert len(block_size) in [ + 2, + 3, + ], f"Slice only works for 2 and 3d block_size right now, got: {block_size}" # with slice, some shape dimension might be smaller than block_size dimension, so # we need to make sure there is no overflow - block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) + if len(block_size) == 2: + block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) + new = self.__class__( aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, @@ -490,7 +493,54 @@ def _(func, types, args, kwargs): self.quant_max, self.zero_point_domain, dtype=self.dtype, - strides=self.stride(), + strides=self.stride() if len(block_size) == 2 else None, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.index.Tensor) +def _(func, types, args, kwargs): + self, indices = args + assert len(indices) == 1, ( + f"op {func} currently only implemented for single dimensional indexing but got indices: {indices}" + ) + + new_tensor_impl = aten.index.Tensor(self.tensor_impl, indices) + shape = tuple([indices[0].numel(), *self.shape[1:]]) + + block_size = self.block_size + new = self.__class__( + new_tensor_impl, + block_size, + shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.select.int) +def _(func, types, args, kwargs): + self, dim, index = fill_defaults(args, 3, [0, 0]) + assert dim == 0, f"op {func} currently only implemented for dim=0 but got dim={dim}" + assert self.dim() == 3, ( + f"op {func} currently only implemented for 3 dimensional tensors but got shape={self.shape}" + ) + + new_tensor_impl = aten.select.int(self.tensor_impl, dim, index) + + shape = self.shape[1:] + block_size = self.block_size[1:] + new = self.__class__( + new_tensor_impl, + block_size, + shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, ) return return_and_correct_aliasing(func, args, kwargs, new) diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 872179bd9a..29c4bf41b0 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -159,6 +159,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs): raise ValueError( f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" ) + elif func in [aten.select.int, func is aten.index.Tensor]: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), + ) elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py index 516136bca7..3551214d7e 100644 --- a/torchao/dtypes/uintx/plain_layout.py +++ b/torchao/dtypes/uintx/plain_layout.py @@ -154,6 
+154,14 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) return return_and_correct_aliasing(func, args, kwargs, new) + elif func in [aten.select.int, aten.index.Tensor]: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), + ) + elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 9c37f58ada..3bf9ef6b72 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -93,11 +93,13 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) # groupwise int4 quantization - groupsize = weight_tensor.block_size[1] - y = torch.ops.aten._weight_int4pack_mm( - act_mat.contiguous(), packed_weight, groupsize, scale_and_zero - ) - + groupsize = weight_tensor.block_size[-1] + if act_mat.numel() == 0: # handling for empty input + y = act_mat + else: + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] y = y[:, :orig_out_features] @@ -119,7 +121,7 @@ class TensorCoreTiledLayout(Layout): inner_k_tiles: int = 8 def pre_process(self, input: torch.Tensor) -> torch.Tensor: - orig_out_features, orig_in_features = input.shape + orig_out_features, orig_in_features = input.shape[-2:] in_features = find_multiple(orig_in_features, 1024) out_features = find_multiple(orig_out_features, 8) input = torch.nn.functional.pad( @@ -160,18 +162,18 @@ def post_process( zero_point: torch.Tensor, block_size: Tuple[int, ...], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - orig_out_features, orig_in_features = input.shape + orig_out_features, orig_in_features = input.shape[-2:] in_features = find_multiple(orig_in_features, 1024) out_features = find_multiple(orig_out_features, 8) input = torch.nn.functional.pad( input, (0, in_features - orig_in_features, 0, out_features - orig_out_features), ) - assert len(block_size) == 2, ( - f"TensorCoreTiledLayout only supports len(block_size) == 2, got: {block_size}" + assert len(block_size) == 2 or len(block_size) == 3, ( + f"TensorCoreTiledLayout only supports len(block_size) == 2 or 3, got: {block_size}" ) - scale_pad_dim_0 = (out_features - orig_out_features) // block_size[0] - scale_pad_dim_1 = (in_features - orig_in_features) // block_size[1] + scale_pad_dim_0 = (out_features - orig_out_features) // block_size[-2] + scale_pad_dim_1 = (in_features - orig_in_features) // block_size[-1] scale = torch.nn.functional.pad(scale, (0, scale_pad_dim_1, 0, scale_pad_dim_0)) zero_point = torch.nn.functional.pad( zero_point, (0, scale_pad_dim_1, 0, scale_pad_dim_0) @@ -262,21 +264,44 @@ def from_plain( _layout: Layout, ): assert isinstance(_layout, TensorCoreTiledLayout) + assert int_data.dtype == torch.int32, ( + "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + ) + + def quant_2d(int_data_2d): + if TORCH_VERSION_AT_LEAST_2_5: + int_data_2d = (int_data_2d[::, ::2] << 4 | int_data_2d[::, 1::2]).to( + torch.uint8 + ) + else: + assert int_data_2d.dtype == torch.int32, ( + "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + ) + return torch.ops.aten._convert_weight_to_int4pack( + 
int_data_2d.contiguous(), _layout.inner_k_tiles + ) - if TORCH_VERSION_AT_LEAST_2_5: - int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert int_data.dtype == torch.uint8, ( - "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + if int_data.dim() == 3: # for moe quant + num_experts = int_data.shape[0] + packed_weight_list = [] + for expert in range(num_experts): + packed_weight_list.append(quant_2d(int_data[expert]).unsqueeze(0)) + packed_weight = torch.cat(packed_weight_list, dim=0) + scale = scale.reshape(int_data.shape[0], int_data.shape[-2], -1) + zero_point = ( + zero_point.reshape(int_data.shape[0], int_data.shape[-2], -1) + if zero_point is not None + else None ) else: - assert int_data.dtype == torch.int32, ( - "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + assert int_data.dim() == 2 + packed_weight = quant_2d(int_data) + scale = scale.reshape(int_data.shape[0], -1) + zero_point = ( + zero_point.reshape(int_data.shape[0], -1) + if zero_point is not None + else None ) - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - int_data, _layout.inner_k_tiles - ) - scale = scale.reshape(int_data.shape[0], -1) - zero_point = zero_point.reshape(int_data.shape[0], -1) from torchao.quantization.utils import pack_tinygemm_scales_and_zeros scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) @@ -336,6 +361,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" ) + if func in [aten.select.int, aten.index.Tensor]: + assert not (func is aten.select.int and args[1] != 0), ( + "aten.select.int currently only has support for dim=0" + ) + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), + ) + if func is aten.t.default: """we don't need to repack the weight and just rely on external shape being changed and record the status of transpose/no-transpose @@ -416,11 +452,16 @@ def block_size(self): scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) cur_shape = self.shape - assert len(cur_shape) == 4 + if len(cur_shape) == 5: + ones = [1, 1] + cur_shape = cur_shape[1:] + else: + assert len(cur_shape) == 4 + ones = [1] inner_k_tiles = cur_shape[-1] * 2 original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) groupsize = int(original_shape[1] / scale.shape[-2]) - return (1, groupsize) + return tuple([*ones, groupsize]) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: from torchao.quantization.quant_primitives import ( @@ -429,35 +470,50 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros - scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + def dequant_4d(self): + cur_shape = self.shape + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + assert len(cur_shape) == 4 + inner_k_tiles = cur_shape[-1] * 2 + original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) + eye_shape = original_shape[1] + groupsize = int(original_shape[1] / scale.shape[-2]) + block_size = (1, groupsize) + original_dtype = torch.bfloat16 + assert len(block_size) == 2 and block_size[0] == 1 + dequantized = torch.ops.aten._weight_int4pack_mm( + torch.eye(eye_shape, device=self.device, dtype=original_dtype), + self.packed_weight, + groupsize, + 
self.scale_and_zero, + ) + dequantized = dequantized.t().contiguous() + return dequantized cur_shape = self.shape - assert len(cur_shape) == 4 - inner_k_tiles = cur_shape[-1] * 2 - original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) - eye_shape = original_shape[1] - groupsize = int(original_shape[1] / scale.shape[-2]) - block_size = (1, groupsize) - device = self.device - original_dtype = torch.bfloat16 + + if len(cur_shape) == 4: + dequantized = dequant_4d(self) + else: + assert len(cur_shape) == 5 + num_experts = cur_shape[0] + dequantized_list = [] + for expert in range(num_experts): + dequantized_list.append(dequant_4d(self[expert]).unsqueeze(0)) + dequantized = torch.cat(dequantized_list, dim=0) + + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + # TODO: move this to `unpack_tinygemm_scales_and_zeros`? + scale = scale.reshape(scale.shape[:-1]).contiguous() + zero = zero.reshape(zero.shape[:-1]).contiguous() + target_dtype = torch.int32 quant_min = 0 quant_max = 15 zero_point_domain = ZeroPointDomain.FLOAT - assert len(block_size) == 2 and block_size[0] == 1 - dequantized = torch.ops.aten._weight_int4pack_mm( - torch.eye(eye_shape, device=device, dtype=original_dtype), - self.packed_weight, - groupsize, - self.scale_and_zero, - ) - dequantized = dequantized.t().contiguous() - # TODO: move this to `unpack_tinygemm_scales_and_zeros`? - scale = scale.reshape(scale.shape[:-1]).contiguous() - zero = zero.reshape(zero.shape[:-1]).contiguous() int_data = quantize_affine( dequantized, - block_size, + self.block_size, scale, zero, target_dtype, diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index da6c98cd6f..38db685f1c 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -629,6 +629,40 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer( sqnr2 = compute_error(prepared_out, converted_out2).item() self.assertTrue(sqnr2 == float("inf")) + def test_moe_quant_intx(self): + from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import ( + MOEFeedForwardAOQuantizable, + ) + from torchao.quantization.prototype.moe_quant.utils import ( + FakeExtraDimTensor, + MoEQuantConfig, + cond_ffn_filter, + ) + from torchao.quantization.quant_api import ( + Int8DynamicActivationIntxWeightConfig, + PackedLinearInt8DynamicActivationIntxWeightLayout, + quantize_, + ) + from torchao.quantization.utils import compute_error + + with torch.device("cpu"): + model = MOEFeedForwardAOQuantizable(512, 256, 8, 2).to(torch.bfloat16) + x = torch.randn(1, 512, dtype=torch.bfloat16) + + out = model(x).clone() + + base_config = Int8DynamicActivationIntxWeightConfig( + layout=PackedLinearInt8DynamicActivationIntxWeightLayout() + ) + moe_config = MoEQuantConfig(base_config) + + quantize_(model, moe_config, cond_ffn_filter) + + out_q = model(x).clone() + assert isinstance(model.experts.w1, FakeExtraDimTensor) + + assert compute_error(out_q, out) > 30, "error bad accuracy but everything ran" + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index e4343a086f..aa946c064f 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -82,6 +82,8 @@ def 
__tensor_unflatten__( def _quantized_linear_op( input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor ): + if input_tensor.numel() == 0: + return input_tensor input_quant_func = weight_tensor.input_quant_func original_weight_tensor = weight_tensor.original_weight_tensor quant_kwargs = weight_tensor.quant_kwargs @@ -243,6 +245,34 @@ def _(func, types, args, kwargs): ) + +@implements(aten.select.int) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + LinearActivationQuantizedTensor( + func(args[0].original_weight_tensor, *args[1:]), + args[0].input_quant_func, + args[0].quant_kwargs, + ), + ) + + +@implements(aten.index.Tensor) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + LinearActivationQuantizedTensor( + func(args[0].original_weight_tensor, *args[1:]), + args[0].input_quant_func, + args[0].quant_kwargs, + ), + ) + + # this is needed for DTensor.from_local() and for flattening tensor @implements(aten.view.default) def _(func, types, args, kwargs): diff --git a/torchao/quantization/prototype/moe_quant/README.md b/torchao/quantization/prototype/moe_quant/README.md new file mode 100644 index 0000000000..c80fb679a1 --- /dev/null +++ b/torchao/quantization/prototype/moe_quant/README.md @@ -0,0 +1,45 @@ +# MoE Quantization + +Our goal with this prototype implementation of moe quantization is to enable usage of existing linear quantization techniques for moe quantization. While it would likely be more performant to use a fused kernel for quantized moe, by decomposing the moe operation into a sequence of linear operations, we can utilize the existing tools and UX that work for linear quantization and apply them to moe. + +Examples of the usage of these apis can be found in both llama4_quant.py and ao/torchao/_models/mixtral-moe/generate.py. + +## Quantization API + +The API for moe quantization is very similar to that for linear quantization, provided the moe module is decomposed into linear operations and is quantizable and compilable. In practice this requires us to use the modules found in quantizable_moe_modules.py or something similar. Once this change has been made the API is as follows for a few different quantization techniques: + +```python + +from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter +from torchao.quantization.quant_api import quantize_, Int8WeightOnlyConfig + +quantize_(model, Int8WeightOnlyConfig(), filter_fn=cond_ffn_filter) +model = torch.compile(model, mode="reduce-overhead") +# you can also use fullgraph=True for single token inference +``` + +This api is the same as for normal linear quantization but with a specific filter function. This works for several different quantization techniques where the quantized tensor subclass has been adapted to work with 3D tensors. Specifically this means Int8WeightOnlyConfig, Int4WeightOnlyConfig, Float8WeightOnlyConfig, Float8DynamicActivationFloat8WeightConfig, and Int8DynamicActivationInt8WeightConfig. It should be noted that due to the minimum tensor input size requirement (>16), Int8DynamicActivationInt8WeightConfig is best used for expert choice moe rather than token choice moe, which is what the rest of the framework in this folder supports. + + +## Alternative Quantization API + +To make the above api work, each tensor subclass had to be edited to work as 3D tensors.
However, since the only ops we actually need to support are a few indexing and slicing ops on the 0th dimension, the majority of the work was removing hard-coded assumptions about the tensor dimensionality. This means it's possible to instead create a new tensor subclass that pretends to be a 3D tensor by storing a series of 2D tensors and simulating the slicing and indexing ops until eventually just returning the singular desired 2D quantized tensor subclass. This can be achieved using the alternative api as follows: + +```python + +from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter, MoEQuantConfig +from torchao.quantization.quant_api import quantize_, Int8DynamicActivationIntxWeightConfig + +config = MoEQuantConfig(Int8DynamicActivationIntxWeightConfig()) + +quantize_(model, config, filter_fn=cond_ffn_filter) +model = torch.compile(model, mode="reduce-overhead") +``` + +While this approach turns out not to be especially performant, it does achieve comparable memory characteristics, allowing models that wouldn't otherwise fit on a single node/gpu to actually run. It is, however, flexible enough to work with all of the existing linear quantization techniques that make use of quantized tensor subclasses without any changes being made to those classes. It is compilable, though even single token inference doesn't work with fullgraph compilation. + +## Model API + +In practice the moe implementations of known models tend not to be easy to quantize, and even those that are quantizable often either compile with many graph breaks or are impossible to torch.compile at all. + +The modules in the quantizable_moe_modules.py file were carefully written to satisfy both of those necessary characteristics, but applying moe quantization to your own model will first require a module swap from the existing MoE module type to these more flexible ones. While there isn't a one-size-fits-all way to do this, an example of how it was done for huggingface's llama4 implementation can be found in llama4_quant.py, which can be seen as a proof of concept. diff --git a/torchao/quantization/prototype/moe_quant/__init__.py b/torchao/quantization/prototype/moe_quant/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/quantization/prototype/moe_quant/llama4_quant.py b/torchao/quantization/prototype/moe_quant/llama4_quant.py new file mode 100644 index 0000000000..d68c0d28c1 --- /dev/null +++ b/torchao/quantization/prototype/moe_quant/llama4_quant.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree.
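+# Example overview: this script swaps each Llama4TextMoe module for
+# MOEFeedForwardAOQuantizable (copying over the router and the permuted expert
+# weights), quantizes the expert weights with an existing linear config
+# (Int4WeightOnlyConfig here) via quantize_ and cond_ffn_filter, and then
+# torch.compiles the model before generating.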
+ +# T tokens +# E experts +# D dim +# I intermediate dim +# A activated experts +# T'(e) tokens for expert e + +import torch +import torch.nn as nn +from transformers import AutoTokenizer, Llama4ForCausalLM +from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + +from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import ( + MOEFeedForwardAOQuantizable, +) +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + + +def llama4_moe_filter_fn(module, fqn): + return isinstance(module, Llama4TextMoe) + + +def convert_fn(module): + # get data + hidden_dim = module.hidden_dim + expert_dim = module.experts.expert_dim + num_experts = module.num_experts + top_k = module.top_k + act_fn = module.experts.act_fn + shared_expert = module.shared_expert + return_scores = True + new_mod = MOEFeedForwardAOQuantizable( + hidden_dim, + expert_dim, + num_experts, + top_k, + act_fn, + shared_expert, + return_scores, + ) + + router = module.router + up_proj = module.experts.gate_up_proj + w1, w3 = up_proj.permute(0, 2, 1).chunk(2, dim=1) + w2 = module.experts.down_proj.permute(0, 2, 1) + + new_mod.router = router + new_mod.experts.w1 = nn.Parameter(w1, requires_grad=False) + new_mod.experts.w2 = nn.Parameter(w2, requires_grad=False) + new_mod.experts.w3 = nn.Parameter(w3, requires_grad=False) + return new_mod + + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +model = Llama4ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +_replace_with_custom_fn_if_matches_filter( + model, + convert_fn, + llama4_moe_filter_fn, +) + +model = model + +from torchao.quantization import Int4WeightOnlyConfig, quantize_ +from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter + +quantize_(model, Int4WeightOnlyConfig(), cond_ffn_filter, device="cuda") + +model.cuda() + +model = torch.compile(model, mode="reduce-overhead") + +prompt = "He is here, the one who will tear apart the very stars" +inputs = tokenizer(prompt, return_tensors="pt") +model.generate(inputs.input_ids.cuda(), max_length=30) +model.generate(inputs.input_ids.cuda(), max_length=30) +generate_ids = model.generate(inputs.input_ids.cuda(), max_length=50) +out = tokenizer.batch_decode( + generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False +)[0] +print(out) diff --git a/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py new file mode 100644 index 0000000000..ec41a8246d --- /dev/null +++ b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py @@ -0,0 +1,186 @@ +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from torchao.quantization.prototype.moe_quant.utils import FakeExtraDimTensor + + +class MOEFeedForwardAOQuantizable(nn.Module): + def __init__( + self, + hidden_dim, + expert_dim, + num_experts, + top_k, + act_fn=F.silu, + shared_expert=None, + return_scores=False, + empty_init=True, + ) -> None: + super().__init__() + self.router = nn.Linear(hidden_dim, num_experts, bias=False) + self.experts = ConditionalFeedForwardAOQuantizable( + num_experts, hidden_dim, expert_dim, act_fn, empty_init + ) + self.hidden_dim = hidden_dim + self.top_k = top_k + self.shared_expert = shared_expert + self.return_scores = return_scores + + def forward(self, x: Tensor) -> Tensor: + batch_size = x.shape[0] + x = x.view(-1, self.hidden_dim) # x: [T, D] + scores = self.router(x) # 
[T, E] + scores = F.softmax(scores, dim=-1) + scores, expert_indices = torch.topk( + scores, self.top_k, dim=-1 + ) # [T, A], [T, A] + scores /= scores.sum(dim=-1, keepdim=True).to(x.dtype) # [T, A] + + out = self.experts(x, expert_indices, scores, self.top_k) + if self.shared_expert: + out += self.shared_expert(x) + + if self.return_scores: + return out.reshape(batch_size, -1, self.hidden_dim), scores + else: + return out.reshape(batch_size, -1, self.hidden_dim) + + +class ConditionalFeedForwardAOQuantizable(nn.Module): + def __init__(self, num_experts, hidden_dim, expert_dim, act_fn, empty_init=True): + super().__init__() + if empty_init: + self.w1 = nn.Parameter( + torch.empty(num_experts, expert_dim, hidden_dim) + ) # E, I, D + self.w2 = nn.Parameter( + torch.empty(num_experts, hidden_dim, expert_dim) + ) # E, D, I + self.w3 = nn.Parameter( + torch.empty(num_experts, expert_dim, hidden_dim) + ) # E, I, D + else: + self.w1 = nn.Parameter( + torch.randn(num_experts, expert_dim, hidden_dim) + ) # E, I, D + self.w2 = nn.Parameter( + torch.randn(num_experts, hidden_dim, expert_dim) + ) # E, D, I + self.w3 = nn.Parameter( + torch.randn(num_experts, expert_dim, hidden_dim) + ) # E, I, D + self.num_experts = num_experts + self.act_fn = act_fn + self.hidden_dim = hidden_dim + self.expert_dim = expert_dim + + def forward( + self, + x: Tensor, # T, D + expert_indices: Tensor, # T, A + expert_weights: Tensor, # T, A + top_k: int, + ) -> Tensor: + num_tokens, _hidden_dim = x.shape + num_token_activations = num_tokens * top_k + + if x.shape[0] == 1 and not isinstance( + self.w1, FakeExtraDimTensor + ): # only 1 token (can be done without graph breaks when compiled) + outs = [] + expert_indices = expert_indices.view(top_k) + # collect used experts + w1 = self.w1[expert_indices] + w2 = self.w2[expert_indices] + w3 = self.w3[expert_indices] + # run token through each expert + for index in range(top_k): + y1 = F.silu(F.linear(x, w1[index])) + y3 = F.linear(x, w3[index]) + y2 = w2[index] + + cur_out = F.linear(y1 * y3, y2) + outs.append(cur_out) + + # combine outputs + final_out = ( + (torch.cat(outs, dim=0) * expert_weights.view(-1, 1)) + .sum(dim=0) + .reshape(x.shape) + ) + return final_out + else: + expert_list = [x for x in range(self.num_experts)] + + # shuffle tokens into groups for each expert + ordered_token_activations = expert_indices.view(-1).argsort( + stable=True + ) # [A] + ordered_token_indices = ( + ordered_token_activations.div(top_k).floor().to(torch.int64) + ) # [T] + num_tokens_per_expert = torch.histc( + expert_indices, + bins=self.num_experts + 1, + min=-1, + max=self.num_experts, + ) # [E+1] (added leading 0 so can be used for indexing) + cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to( + torch.int64 + ) # [E+1] + + @torch._dynamo.disable() + def group_tokens_by_expert( + ordered_token_indices, cum_tokens_per_expert, expert_list + ): + token_indices_per_expert = [ + ordered_token_indices[ + cum_tokens_per_expert[expert] : cum_tokens_per_expert[ + expert + 1 + ] + ].to(torch.int64) + for expert in expert_list + ] # [T'(e1)], [T'(e2)] ... 
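+                # token_indices_per_expert[e] holds the row indices of x that were routed
+                # to expert e, grouped contiguously; the per-expert lengths are data
+                # dependent, which is presumably why this helper is kept out of the
+                # compiled graph via torch._dynamo.disable().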
+ return token_indices_per_expert + + token_indices_per_expert = group_tokens_by_expert( + ordered_token_indices, cum_tokens_per_expert, expert_list + ) + tokens_grouped_by_expert = [ + x[indices] for indices in token_indices_per_expert + ] + + # calculate outputs for each expert + outs = [] + for cur_x, expert in zip(tokens_grouped_by_expert, expert_list): + w1 = self.w1[expert] # I, D + w2 = self.w2[expert] # D, I + w3 = self.w3[expert] # I, D + + y1 = F.silu(F.linear(cur_x, w1)) + y3 = F.linear(cur_x, w3) + y2 = w2 + + cur_out = F.linear(y1 * y3, y2) # [T'(e), D] + outs.append(cur_out) + + # weigh outputs + ordered_outs = torch.cat(outs, dim=0) # [T*A, D] + ordered_token_activation_weights = expert_weights.view(-1, 1)[ + ordered_token_activations + ].view(-1, 1) # [T*A, 1] + weighted_ordered_outs = ( + ordered_outs * ordered_token_activation_weights + ) # [T*A, D] + + # sum weighted token-activation outputs together for each token + final_out = torch.zeros_like(x) # [T, D] + final_out = final_out.scatter_add( + dim=0, + index=ordered_token_indices.unsqueeze(-1) + .expand(num_token_activations, self.hidden_dim) + .to(torch.int64), + src=weighted_ordered_outs, + ) + return final_out diff --git a/torchao/quantization/prototype/moe_quant/utils.py b/torchao/quantization/prototype/moe_quant/utils.py new file mode 100644 index 0000000000..22e786bc3c --- /dev/null +++ b/torchao/quantization/prototype/moe_quant/utils.py @@ -0,0 +1,261 @@ +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +aten = torch.ops.aten + +from typing import List, Optional, Tuple, Union + +from torchao.quantization.quant_api import ( + AOBaseConfig, + dataclass, + register_quantize_module_handler, +) +from torchao.utils import fill_defaults + + +class DummyModule(torch.nn.Module): + """This is used because the TorchAO quantization functions tend to operate on modules so to apply the transform to a tensor, we can load a + DummyModule with the target tensor and then apply the transformation to the module and then extract the transformed tensor. + """ + + def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): + super().__init__() + self.weight = weight + self.bias = bias + + +class FakeExtraDimTensor(torch.Tensor): + """This is a subclass of torch.Tensor that simulates a tensor of n+1 dimensions, akin to concatenating several tensors along the 0th dimension. + It takes a list of tensors with the same dtype, device and shape and creates a representation of shape (num_tensors, orig_shape). It can handle a + variety of ops like detach and clone but most importantly, supports any slicing and indexing along the extra dimension. + This is most useful when you have another tensor subclass that you'd like to concatenate together but don't want to support all the necessary + pieces of 3D scaffolding required to make it work. + + The structure of this tensor subclass is a linked_list of tensors with each instance of FakeExtraDimTensor containing a head tensor and a tail consisting of + either another intance of FakeExtraDimTensor or None if we've reached the end of the linked list. This implementation structure is necessary to + support compilation of this tensor subclass since compile requires each tensor component of the tensor subclass to have its own attribute. 
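+
+    Illustrative usage (qt0, qt1 and x below are stand-ins for two same-shape 2D quantized
+    tensor subclass instances and an activation tensor; they are not defined in this file):
+
+        w = FakeExtraDimTensor([qt0, qt1])       # acts like a tensor of shape (2, *qt0.shape)
+        w0 = w[0]                                # select along dim 0, wraps just qt0
+        y = torch.nn.functional.linear(x, w0)    # dispatches to the underlying 2D subclass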
+ """ + + def __new__( + cls, + tensors: Union[Tuple[torch.Tensor], List[torch.Tensor]], + tensor_tail: Optional["FakeExtraDimTensor"] = None, + ): + assert len(tensors) > 0 or tensor_tail is not None + num_tensors = len(tensors) + if tensor_tail is not None: + num_tensors += tensor_tail.num_tensors + test_tensor = tensor_tail.head_tensor + else: + test_tensor = tensors[0] + + dtype = test_tensor.dtype + shape = test_tensor.shape + device = test_tensor.device + layout = test_tensor.layout + for tensor in tensors: + assert tensor.dtype == dtype, ( + f"all tensors in FakeExtraDimTensor must have same dtype but got {tensor.dtype} and {dtype}" + ) + assert tensor.shape == shape, ( + f"all tensors in FakeExtraDimTensor must have same shape but got {tensor.shape} and {shape}" + ) + assert tensor.device == device, ( + f"all tensors in FakeExtraDimTensor must have same device but got {tensor.device} and {device}" + ) + assert tensor.layout == layout, ( + f"all tensors in FakeExtraDimTensor must have same layout but got {tensor.layout} and {layout}" + ) + kwargs = {} + kwargs["dtype"] = dtype + kwargs["layout"] = layout + kwargs["device"] = device + kwargs["requires_grad"] = False + new_shape = (num_tensors, *shape) + return torch.Tensor._make_wrapper_subclass(cls, new_shape, **kwargs) + + def __repr__( + self, + ): + return f"{self.__class__.__name__}(shape={self.shape}, containing {self.num_tensors}: {self.head_tensor})" + + def __init__( + self, + tensors: Union[Tuple[torch.Tensor], List[torch.Tensor]], + tensor_tail: Optional["FakeExtraDimTensor"] = None, + ): + tensors = list(tensors) + assert len(tensors) > 0 or tensor_tail is not None + + # count num_tensors and make tensor_list + self.num_tensors = len(tensors) + if tensor_tail is not None: + self.num_tensors += tensor_tail.num_tensors + tail_list = tensor_tail.tensor_list + else: + tail_list = [] + self.tensor_list = tensors + tail_list + + # 3 cases + # 0) tensors has 0 elements -> take element from tail then do case 1 instead + # 1) tensors has 1 element, -> pop element and tail is None + # 2) tensors has >1 elements, -> pop element and recurse + + # convert case 0 to case 1 by taking 1 element from tail + if len(tensors) == 0 and tensor_tail is not None: + tensors = [ + tensor_tail.head_tensor, + ] + tensor_tail = tensor_tail.tensor_tail + + if len(tensors) > 1: + # case (1): remove first element from tensors, then recurse + self.head_tensor = tensors[0] # remove one + self.tensor_tail = self.__class__(tensors[1:], tensor_tail) # recurse + elif len(tensors) == 1: + # case (2) take final element from tensors, attach tensor_tail then stop recursion + self.head_tensor = tensors[0] + self.tensor_tail = tensor_tail + + def _apply_fn_to_data(self, fn): + self.head_tensor = fn(self.head_tensor) + if self.tensor_tail is not None: + self.tensor_tail = self.tensor_tail._apply_fn_to_data(fn) + return self.__class__([self.head_tensor], self.tensor_tail) + + def __tensor_flatten__(self): + if self.tensor_tail is None: + return [ + "head_tensor", + ], [self.num_tensors] + else: + return [ + "head_tensor", + "tensor_tail", + ], [self.num_tensors] + + @classmethod + def __tensor_unflatten__( + cls, + tensor_data_dict, + tensor_attributes, + outer_size, + outer_stride, + ): + head_tensor = tensor_data_dict["head_tensor"] + tensor_tail = tensor_data_dict.get("tensor_tail", None) + return cls([head_tensor], tensor_tail) + + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + kwargs = {} if kwargs is None else kwargs + if func is 
torch.nn.functional.linear: + x, w, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert w.num_tensors == 1, ( + "FakeExtraDimTensor used in a linear op when it had multiple tensors" + ) + return func(x, w.head_tensor, bias) + try: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + except Exception as e: + print(f"ERR: subclass {cls} doesn't implement {func}, got error: {e}") + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func == aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + return return_and_correct_aliasing( + func, args, kwargs, cls(self.tensor_list[start:end:step]) + ) + + elif func == aten.select.int: + self, dim, index = fill_defaults(args, 3, [0, 0]) + if dim == 0: + return return_and_correct_aliasing( + func, args, kwargs, cls([self.tensor_list[index]]) + ) + elif func == aten.index.Tensor: + self, indices, dim = fill_defaults(args, 3, [0]) + if dim == 0: + # this handles a weird bug where indices gets turned into a list + # between the function dispatch and torch dispatch but just for this function + if isinstance(indices, list) and len(indices) == 1: + indices = indices[0] + return return_and_correct_aliasing( + func, + args, + kwargs, + cls([self.tensor_list[index] for index in indices]), + ) + try: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), + ) + except Exception as e: + print( + f"function {func} failed for FakeExtraDimTensor, following error occured when trying to" + "run function on its elements: " + ) + raise e + + +@dataclass +class MoEQuantConfig(AOBaseConfig): + """Configuration for applying quantization to MoE + Args: + `base_config`: normal AO Config + """ + + base_config: AOBaseConfig + + +@register_quantize_module_handler(MoEQuantConfig) +def moe_quant_fn(module, config: MoEQuantConfig): + import warnings + + warnings.simplefilter("ignore", lineno=84) + warnings.simplefilter("ignore", lineno=105) + assert "ConditionalFeedForwardAOQuantizable" in str(type(module)) + from torchao.quantization.quant_api import _QUANTIZE_CONFIG_HANDLER + + for weight_attr in ["w1", "w2", "w3"]: + param = getattr(module, weight_attr) + assert isinstance(config.base_config, AOBaseConfig), ( + f"MoEQuantConfig expected to be initialized with an AOBaseConfig but got {type(config.base_config)}" + + "this can happen if you initiaze with MoEQuantConfig(AOConfig) rather than MoEQuantConfig(AOConfig())" + ) + handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + + # break 3D tensor + tensors = [param[i] for i in range(param.shape[0])] + # put tensors into modules since the handlers target modules not tensors + dummy_modules = [DummyModule(tensor) for tensor in tensors] + # apply handler to each module + out_mods = list(map(lambda x: handler(x, config.base_config), dummy_modules)) + # pack quantized subclasses into FakeExtraDimTensor + new_param = FakeExtraDimTensor([mod.weight for mod in out_mods]) + new_param = torch.nn.Parameter(new_param, requires_grad=False) + setattr(module, weight_attr, new_param) + del param + return module + + +def moe_filter(module, fqn): + return "MOEFeedForwardAOQuantizable" in str(type(module)) + + +def cond_ffn_filter(module, fqn): + return "ConditionalFeedForwardAOQuantizable" in str(type(module)) diff --git a/torchao/quantization/quant_api.py 
b/torchao/quantization/quant_api.py index 890c2e2038..471af07858 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -300,7 +300,7 @@ def _replace_with_custom_fn_if_matches_filter( device, extra_args, ) - if new_child is not child: + if new_child is not child and new_child is not None: setattr(model, name, new_child) if device is not None: model.to(device=device) # move parent module to device @@ -1050,31 +1050,25 @@ class Int4WeightOnlyConfig(AOBaseConfig): int4_weight_only = Int4WeightOnlyConfig -@register_quantize_module_handler(Int4WeightOnlyConfig) -def _int4_weight_only_transform( - module: torch.nn.Module, config: Int4WeightOnlyConfig -) -> torch.nn.Module: +def _int4_weight_only_quantize_tensor(weight, config): # TODO(future PR): perhaps move this logic to a different file, to keep the API # file clean of implementation details # for now, make these local variables to allow the rest of the function # to be a direct copy-paste - weight = module.weight group_size = config.group_size layout = config.layout use_hqq = config.use_hqq zero_point_domain = config.zero_point_domain - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() if weight.shape[-1] % group_size != 0: logger.info( f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}" ) - return module + return weight mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) + block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size]) target_dtype = torch.int32 quant_min = 0 quant_max = 15 @@ -1126,9 +1120,32 @@ def _int4_weight_only_transform( _layout=layout, use_hqq=use_hqq, ) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module + return new_weight + + +@register_quantize_module_handler(Int4WeightOnlyConfig) +def _int4_weight_only_transform( + module: torch.nn.Module, config: Int4WeightOnlyConfig +) -> torch.nn.Module: + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + if "ConditionalFeedForwardAOQuantizable" in str(type(module)): + for weight_attr in ["w1", "w2", "w3"]: + weight = getattr(module, weight_attr) + new_weight = _int4_weight_only_quantize_tensor(weight, config) + new_weight = torch.nn.Parameter(new_weight, requires_grad=False) + setattr(module, weight_attr, new_weight) + return module + else: + assert hasattr(module, "weight"), ( + "applying int8 weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + new_weight = _int4_weight_only_quantize_tensor(module.weight, config) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module @dataclass @@ -1145,20 +1162,15 @@ class Int8WeightOnlyConfig(AOBaseConfig): int8_weight_only = Int8WeightOnlyConfig -@register_quantize_module_handler(Int8WeightOnlyConfig) -def _int8_weight_only_transform(module: torch.nn.Module, config: Int8WeightOnlyConfig): - group_size = config.group_size - weight = module.weight - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - +def _int8_weight_only_quantize_tensor(weight, config): mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 + 
group_size = config.group_size if group_size is None: - group_size = weight.shape[1] - block_size = (1, group_size) + group_size = weight.shape[-1] + block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) new_weight = to_affine_quantized_intx( weight, mapping_type, @@ -1167,9 +1179,30 @@ def _int8_weight_only_transform(module: torch.nn.Module, config: Int8WeightOnlyC eps=eps, zero_point_dtype=zero_point_dtype, ) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module + return new_weight + + +@register_quantize_module_handler(Int8WeightOnlyConfig) +def _int8_weight_only_transform(module: torch.nn.Module, config: Int8WeightOnlyConfig): + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + if "ConditionalFeedForwardAOQuantizable" in str(type(module)): + for weight_attr in ["w1", "w2", "w3"]: + weight = getattr(module, weight_attr) + new_weight = _int8_weight_only_quantize_tensor(weight, config) + new_weight = torch.nn.Parameter(new_weight, requires_grad=False) + setattr(module, weight_attr, new_weight) + return module + else: + assert hasattr(module, "weight"), ( + "applying int8 weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + new_weight = _int8_weight_only_quantize_tensor(module.weight, config) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor: @@ -1283,33 +1316,26 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig -@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) -def _int8_dynamic_activation_int8_weight_transform( - module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig -) -> torch.nn.Module: +def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): layout = config.layout act_mapping_type = config.act_mapping_type weight_only_decode = config.weight_only_decode - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - - weight = module.weight - in_features = weight.shape[1] + in_features = weight.shape[-1] # int8 dynamic quantization only has benefit when in_feature > 16 if in_features <= 16: logger.info( f"Skipping applying int8_dynamic_activation_int8_weight to weight of shape {weight.shape}" f" because `in_feature` is <= 16: {in_features}" ) - return module + return weight # weight settings mapping_type = MappingType.SYMMETRIC weight_zero_point_domain = ZeroPointDomain.NONE def get_weight_block_size(x): - return (1, x.shape[1]) + return tuple([1 for _ in range(x.dim() - 1)] + [x.shape[-1]]) target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps @@ -1325,7 +1351,7 @@ def get_weight_block_size(x): input_quant_func = _int8_asymm_per_token_quant block_size = get_weight_block_size(weight) - weight = to_affine_quantized_intx( + new_weight = to_affine_quantized_intx( weight, mapping_type, block_size, @@ -1335,9 +1361,36 @@ def get_weight_block_size(x): _layout=layout, zero_point_domain=weight_zero_point_domain, ) - weight = to_linear_activation_quantized(weight, input_quant_func) - module.weight = torch.nn.Parameter(weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) 
+ new_weight = to_linear_activation_quantized(new_weight, input_quant_func) + return new_weight + + +@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) +def _int8_dynamic_activation_int8_weight_transform( + module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig +) -> torch.nn.Module: + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + if "ConditionalFeedForwardAOQuantizable" in str(type(module)): + for weight_attr in ["w1", "w2", "w3"]: + weight = getattr(module, weight_attr) + new_weight = _int8_dynamic_activation_int8_weight_quantize_tensor( + weight, config + ) + new_weight = torch.nn.Parameter(new_weight, requires_grad=False) + setattr(module, weight_attr, new_weight) + return module + else: + assert hasattr(module, "weight"), ( + "applying int8 dynamic activation int8 weight quant requires module to have weight attribute" + + "but {module} does not have one" + ) + new_weight = _int8_dynamic_activation_int8_weight_quantize_tensor( + module.weight, config + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -1375,17 +1428,10 @@ class Float8WeightOnlyConfig(AOBaseConfig): float8_weight_only = Float8WeightOnlyConfig -@register_quantize_module_handler(Float8WeightOnlyConfig) -def _float8_weight_only_transform( - module: torch.nn.Module, config: Float8WeightOnlyConfig -) -> torch.nn.Module: +def _float8_weight_only_quant_tensor(weight, config): from torchao.dtypes import to_affine_quantized_floatx - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - - weight = module.weight - block_size = (1, weight.shape[1]) + block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]]) new_weight = to_affine_quantized_floatx( input_float=weight, block_size=block_size, @@ -1393,9 +1439,33 @@ def _float8_weight_only_transform( scale_dtype=None, _layout=Float8Layout(mm_config=None), ) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module + return new_weight + + +@register_quantize_module_handler(Float8WeightOnlyConfig) +def _float8_weight_only_transform( + module: torch.nn.Module, config: Float8WeightOnlyConfig +) -> torch.nn.Module: + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + if "ConditionalFeedForwardAOQuantizable" in str(type(module)): + for weight_attr in ["w1", "w2", "w3"]: + weight = getattr(module, weight_attr) + new_weight = _float8_weight_only_quant_tensor(weight, config) + new_weight = torch.nn.Parameter(new_weight, requires_grad=False) + setattr(module, weight_attr, new_weight) + return module + else: + assert hasattr(module, "weight"), ( + "applying int8 weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + new_weight = _float8_weight_only_quant_tensor(module.weight, config) + + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module _fp8_granularities = Union[PerTensor, PerRow] @@ -1496,11 +1566,12 @@ def _fp8_mm_compat(weight: torch.Tensor) -> bool: Returns: bool: True if the tensor can be quantized to float8, False otherwise """ - assert weight.dim() == 2, ( - f"float8 quantization only works for 2-D tensors, got {weight.dim()}D tensor" - ) + 
assert weight.dim() in [ + 2, + 3, + ], f"float8 quantization only works for 2/3-D tensors, got {weight.dim()}D tensor" - out_dim, in_dim = weight.shape + out_dim, in_dim = weight.shape[-2:] is_compatible = (in_dim % 16 == 0) and (out_dim % 16 == 0) if not is_compatible: @@ -1547,34 +1618,26 @@ def __post_init__(self): float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig -@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig) -def _float8_dynamic_activation_float8_weight_transform( - module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig -): - assert is_sm_at_least_89() or is_MI300(), ( - "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" - ) - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - +def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype granularity = config.granularity mm_config = config.mm_config - weight = module.weight activation_granularity, weight_granularity = _normalize_granularity(granularity) if not _fp8_mm_compat(weight): # TODO(future PR): this should really throw an exception instead of silently # not doing what the user asked - return module + return weight if isinstance(weight_granularity, PerRow): assert weight.dtype == torch.bfloat16, ( "PerRow quantization only works for bfloat16 precision input weight" ) - block_size = get_block_size(weight.shape, weight_granularity) + block_size = get_block_size(weight.shape[-2:], weight_granularity) + if weight.dim() == 3: + block_size = tuple([1] + list(block_size)) quantized_weight = to_affine_quantized_floatx( input_float=weight, block_size=block_size, @@ -1592,10 +1655,39 @@ def _float8_dynamic_activation_float8_weight_transform( quantized_weight = to_linear_activation_quantized( quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs ) + return quantized_weight - module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module + +@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig) +def _float8_dynamic_activation_float8_weight_transform( + module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig +): + assert is_sm_at_least_89() or is_MI300(), ( + "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" + ) + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + if "ConditionalFeedForwardAOQuantizable" in str(type(module)): + for weight_attr in ["w1", "w2", "w3"]: + weight = getattr(module, weight_attr) + quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( + weight, config + ) + new_weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + setattr(module, weight_attr, new_weight) + return module + else: + assert hasattr(module, "weight"), ( + "applying float8 dynamic activation quant requires module to have weight attribute" + + f"but {module} does not have one" + ) + quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( + module.weight, config + ) + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module @dataclass diff --git a/torchao/quantization/transform_module.py 
b/torchao/quantization/transform_module.py index b6fac49ae9..339d46be35 100644 --- a/torchao/quantization/transform_module.py +++ b/torchao/quantization/transform_module.py @@ -47,5 +47,6 @@ def _transform( @functools.wraps(config_type) def decorator(func): _QUANTIZE_CONFIG_HANDLER[config_type] = func + return func # needed to make the functions usable externally return decorator diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 0c30fba713..a9cad8060e 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -365,22 +365,23 @@ def get_groupwise_affine_qparams( def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16): guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size()) guard_dtype_size(zeros, "zeros", dtype=dtype) + dim = scales.dim() return ( torch.cat( [ - scales.reshape(scales.size(0), scales.size(1), 1), - zeros.reshape(zeros.size(0), zeros.size(1), 1), + scales.unsqueeze(-1), + zeros.unsqueeze(-1), ], - 2, + dim, ) - .transpose(0, 1) + .transpose(-3, -2) .contiguous() ) def unpack_tinygemm_scales_and_zeros(scales_and_zeros): - assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 - return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) + assert scales_and_zeros.shape[-1] == 2 + return torch.split(scales_and_zeros.transpose(-3, -2), 1, -1) def convert_weight_to_int4pack_xpu(weight, zero_point_domain_is_int=False): diff --git a/torchao/utils.py b/torchao/utils.py index db269b4cb0..280da4e632 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -10,7 +10,7 @@ from functools import reduce from importlib.metadata import version from math import gcd -from typing import Any, Callable, Tuple +from typing import Any, Callable import torch import torch.nn.utils.parametrize as parametrize @@ -170,7 +170,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs): return measurement.mean * 1e6 -def find_multiple(n: int, *args: Tuple[int]) -> int: +def find_multiple(n: int, *args: int) -> int: k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) # type: ignore[9] if n % k == 0: return n From f316a1cebfbc7b97cd8ea2ab059bc43264342fa7 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 7 May 2025 23:50:34 -0700 Subject: [PATCH 02/18] fixing CI Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_moe_quant.py | 19 +++++++++++++++++++ ...est_int8_dynamic_activation_intx_weight.py | 14 +++++++++++--- .../moe_quant/quantizable_moe_modules.py | 16 ++++++++++------ 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py index cbeb3ad308..a18b3519b1 100644 --- a/test/quantization/test_moe_quant.py +++ b/test/quantization/test_moe_quant.py @@ -169,6 +169,25 @@ def test_int8wo_base(self, name, num_tokens, fullgraph): fullgraph=fullgraph, ) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") + @parameterized.expand( + [ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ] + ) + def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): + config = Int8WeightOnlyConfig() + tensor_impl_class = PlainAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph, + device="cpu", + ) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") 
@parameterized.expand( diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index 38db685f1c..06f4f3e8a6 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -646,8 +646,8 @@ def test_moe_quant_intx(self): from torchao.quantization.utils import compute_error with torch.device("cpu"): - model = MOEFeedForwardAOQuantizable(512, 256, 8, 2).to(torch.bfloat16) - x = torch.randn(1, 512, dtype=torch.bfloat16) + model = MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False).to(torch.bfloat16) + x = torch.randn(8, 512, dtype=torch.bfloat16) out = model(x).clone() @@ -661,7 +661,15 @@ def test_moe_quant_intx(self): out_q = model(x).clone() assert isinstance(model.experts.w1, FakeExtraDimTensor) - assert compute_error(out_q, out) > 30, "error bad accuracy but everything ran" + mod_c = torch.compile(model, mode="reduce-overhead") + + mod_c(x) + mod_c(x) + + out_qc = mod_c(x).clone() + + self.assertGreater(compute_error(out_q, out), 30) + self.assertGreater(compute_error(out_qc, out), 30) if __name__ == "__main__": diff --git a/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py index ec41a8246d..5c938f62d6 100644 --- a/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py +++ b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py @@ -120,12 +120,16 @@ def forward( ordered_token_indices = ( ordered_token_activations.div(top_k).floor().to(torch.int64) ) # [T] - num_tokens_per_expert = torch.histc( - expert_indices, - bins=self.num_experts + 1, - min=-1, - max=self.num_experts, - ) # [E+1] (added leading 0 so can be used for indexing) + if not expert_indices.is_cuda: # histc doesn't work on cpu for integers + num_tokens_per_expert = torch.bincount(expert_indices.view(-1)+1, minlength=self.num_experts+1) + else: + num_tokens_per_expert = torch.histc( + expert_indices, + bins=self.num_experts + 1, + min=-1, + max=self.num_experts, + ) # [E+1] (added leading 0 so can be used for indexing) + # num_tokens_per_expert = torch.bincount(expert_indices.view(-1)+1, minlength=self.num_experts+1) cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to( torch.int64 ) # [E+1] From c2660fdf1fe7fc2e669b8fc7ca839a01d72d1354 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 00:06:36 -0700 Subject: [PATCH 03/18] fixing CI Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/_models/mixtral-moe/README.md | 3 +++ torchao/_models/mixtral-moe/scripts/download.py | 2 +- torchao/_models/mixtral-moe/scripts/prepare.sh | 2 ++ .../tests/test_int8_dynamic_activation_intx_weight.py | 4 +++- .../prototype/moe_quant/quantizable_moe_modules.py | 6 ++++-- 5 files changed, 13 insertions(+), 4 deletions(-) create mode 100644 torchao/_models/mixtral-moe/README.md create mode 100644 torchao/_models/mixtral-moe/scripts/prepare.sh diff --git a/torchao/_models/mixtral-moe/README.md b/torchao/_models/mixtral-moe/README.md new file mode 100644 index 0000000000..b68a3c9ff2 --- /dev/null +++ b/torchao/_models/mixtral-moe/README.md @@ -0,0 +1,3 @@ +This is the benchmarking setup primarily used for testing quantized moe. 
You can reproduce the above numbers by running + +`sh scripts/prepare.sh` diff --git a/torchao/_models/mixtral-moe/scripts/download.py b/torchao/_models/mixtral-moe/scripts/download.py index ec9077d3dc..8a451b001d 100644 --- a/torchao/_models/mixtral-moe/scripts/download.py +++ b/torchao/_models/mixtral-moe/scripts/download.py @@ -37,7 +37,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) - parser.add_argument( "--repo_id", type=str, - default="checkpoints/mistralai/Mixtral-8x7B-Instruct-v0.1", + default="mistralai/Mixtral-8x7B-Instruct-v0.1", help="Repository ID to download from.", ) parser.add_argument( diff --git a/torchao/_models/mixtral-moe/scripts/prepare.sh b/torchao/_models/mixtral-moe/scripts/prepare.sh new file mode 100644 index 0000000000..72d59edc7e --- /dev/null +++ b/torchao/_models/mixtral-moe/scripts/prepare.sh @@ -0,0 +1,2 @@ +python scripts/download.py --repo_id mistralai/Mixtral-8x7B-Instruct-v0.1 +python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/mistralai/Mixtral-8x7B-v0.1 diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index 06f4f3e8a6..2a3ede6563 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -646,7 +646,9 @@ def test_moe_quant_intx(self): from torchao.quantization.utils import compute_error with torch.device("cpu"): - model = MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False).to(torch.bfloat16) + model = MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False).to( + torch.bfloat16 + ) x = torch.randn(8, 512, dtype=torch.bfloat16) out = model(x).clone() diff --git a/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py index 5c938f62d6..bc9c9eb2ea 100644 --- a/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py +++ b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py @@ -120,8 +120,10 @@ def forward( ordered_token_indices = ( ordered_token_activations.div(top_k).floor().to(torch.int64) ) # [T] - if not expert_indices.is_cuda: # histc doesn't work on cpu for integers - num_tokens_per_expert = torch.bincount(expert_indices.view(-1)+1, minlength=self.num_experts+1) + if not expert_indices.is_cuda: # histc doesn't work on cpu for integers + num_tokens_per_expert = torch.bincount( + expert_indices.view(-1) + 1, minlength=self.num_experts + 1 + ) else: num_tokens_per_expert = torch.histc( expert_indices, From 9e687ccf00b5186371b21ac44b7e47f6ba820364 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 01:07:23 -0700 Subject: [PATCH 04/18] fixing CI Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/_models/mixtral-moe/README.md | 7 ++++++- torchao/_models/mixtral-moe/generate.py | 14 +++++++++----- torchao/_models/mixtral-moe/model.py | 14 +++++++++++--- torchao/_models/mixtral-moe/run.sh | 12 ++++++------ torchao/_models/mixtral-moe/scripts/prepare.sh | 2 +- .../prototype/moe_quant/quantizable_moe_modules.py | 1 - 6 files changed, 33 insertions(+), 17 deletions(-) diff --git a/torchao/_models/mixtral-moe/README.md b/torchao/_models/mixtral-moe/README.md index b68a3c9ff2..22c318aab9 100644 --- a/torchao/_models/mixtral-moe/README.md +++ b/torchao/_models/mixtral-moe/README.md @@ -1,3 +1,8 @@ -This is the benchmarking setup primarily used for 
testing quantized moe. You can reproduce the above numbers by running +## Mixtral-MoE + +This folder contains code and scripts for benchmarking the Mixtral-MoE model. +Running `sh scripts/prepare.sh` + +should download the model and `sh run.sh` will run teh benchmarks. diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py index c9cf8e6f37..b21f5923cd 100644 --- a/torchao/_models/mixtral-moe/generate.py +++ b/torchao/_models/mixtral-moe/generate.py @@ -208,7 +208,6 @@ def main( assert checkpoint_path.is_file(), checkpoint_path tokenizer_path = checkpoint_path.parent / "tokenizer.model" assert tokenizer_path.is_file(), str(tokenizer_path) - print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") print(f"Using device={device}") precision = torch.bfloat16 is_chat = "chat" in str(checkpoint_path) @@ -220,10 +219,10 @@ def main( print("Loading model ...") t0 = time.time() - model = _load_model(checkpoint_path, device, precision) + model = _load_model(checkpoint_path, "cpu", precision) - device_sync(device=device) # MKG print(f"Time to load model: {time.time() - t0:.02f} seconds") + t0 = time.time() tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) @@ -299,7 +298,12 @@ def main( if config is not None: quantize_(model, config, filter_fn=cond_ffn_filter) - torch.cuda.reset_peak_memory_stats() + print(f"Time to apply quantization to model: {time.time() - t0:.02f} seconds") + + model.to(device=device) + device_sync(device=device) + + print(f"C: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") if compile: # moe quant + compile causes repeated warnings @@ -382,7 +386,7 @@ def callback(x): if not interactive: pass - print(tokenizer.decode(y[0].tolist())) + # print(tokenizer.decode(y[0].tolist())) else: print() tokens_generated = y.size(-1) - prompt_length diff --git a/torchao/_models/mixtral-moe/model.py b/torchao/_models/mixtral-moe/model.py index 48dda5bef9..46a4ce79be 100644 --- a/torchao/_models/mixtral-moe/model.py +++ b/torchao/_models/mixtral-moe/model.py @@ -395,9 +395,17 @@ def forward( .to(torch.int64) ) # [T] - num_tokens_per_expert = torch.histc( - expert_indices, bins=self.num_experts + 1, min=-1, max=self.num_experts - ) # [E+1] (added leading 0 so can be used for indexing) + if not expert_indices.is_cuda: # histc doesn't work on cpu for integers + num_tokens_per_expert = torch.bincount( + expert_indices.view(-1) + 1, minlength=self.num_experts + 1 + ) + else: + num_tokens_per_expert = torch.histc( + expert_indices, + bins=self.num_experts + 1, + min=-1, + max=self.num_experts, + ) # [E+1] (added leading 0 so can be used for indexing) cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to( torch.int64 ) # [E+1] diff --git a/torchao/_models/mixtral-moe/run.sh b/torchao/_models/mixtral-moe/run.sh index 692acfa1c4..d9e3a50405 100644 --- a/torchao/_models/mixtral-moe/run.sh +++ b/torchao/_models/mixtral-moe/run.sh @@ -1,5 +1,5 @@ export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1 -export CHECKPOINT_PATH=~/checkpoints/ +export CHECKPOINT_PATH=checkpoints/ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --compile python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --compile @@ -16,11 +16,11 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --ba python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 
--moe_quant int4wo-base --compile python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int4wo-base --compile -# EXPERT CHOICE -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq --compile -# # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq --compile -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq-base --compile -# # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq-base --compile +# # EXPERT CHOICE +# # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq --compile +# # # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq --compile +# # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq-base --compile +# # # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq-base --compile python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8wo --compile python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8wo --compile diff --git a/torchao/_models/mixtral-moe/scripts/prepare.sh b/torchao/_models/mixtral-moe/scripts/prepare.sh index 72d59edc7e..8ca60b165b 100644 --- a/torchao/_models/mixtral-moe/scripts/prepare.sh +++ b/torchao/_models/mixtral-moe/scripts/prepare.sh @@ -1,2 +1,2 @@ python scripts/download.py --repo_id mistralai/Mixtral-8x7B-Instruct-v0.1 -python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/mistralai/Mixtral-8x7B-v0.1 +python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/mistralai/Mixtral-8x7B-Instruct-v0.1 diff --git a/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py index bc9c9eb2ea..516341a3a8 100644 --- a/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py +++ b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py @@ -131,7 +131,6 @@ def forward( min=-1, max=self.num_experts, ) # [E+1] (added leading 0 so can be used for indexing) - # num_tokens_per_expert = torch.bincount(expert_indices.view(-1)+1, minlength=self.num_experts+1) cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to( torch.int64 ) # [E+1] From 5a1402167c481c690cb1ffd3f229ef4021103165 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 01:09:07 -0700 Subject: [PATCH 05/18] lint Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/_models/mixtral-moe/generate.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py index b21f5923cd..c82d228886 100644 --- a/torchao/_models/mixtral-moe/generate.py +++ b/torchao/_models/mixtral-moe/generate.py @@ -298,8 +298,10 @@ def main( if config is not None: quantize_(model, config, filter_fn=cond_ffn_filter) - print(f"Time to apply quantization to model: {time.time() - t0:.02f} seconds") - + print( + f"Time to apply quantization to model: {time.time() - t0:.02f} seconds" + ) + model.to(device=device) device_sync(device=device) From c20bd253447ec9be1f2d109686fc7dcf123ef58a 
Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 01:11:36 -0700 Subject: [PATCH 06/18] remove test code Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/_models/mixtral-moe/generate.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py index c82d228886..393b82a388 100644 --- a/torchao/_models/mixtral-moe/generate.py +++ b/torchao/_models/mixtral-moe/generate.py @@ -305,8 +305,6 @@ def main( model.to(device=device) device_sync(device=device) - print(f"C: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") - if compile: # moe quant + compile causes repeated warnings import warnings From d73fbd82ba4913d6ce2fac6c1e3ed50f5aed642c Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 01:19:36 -0700 Subject: [PATCH 07/18] fixing exp test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- .../tests/test_int8_dynamic_activation_intx_weight.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index 2a3ede6563..a5c6437e8e 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -647,9 +647,9 @@ def test_moe_quant_intx(self): with torch.device("cpu"): model = MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False).to( - torch.bfloat16 + torch.float32 ) - x = torch.randn(8, 512, dtype=torch.bfloat16) + x = torch.randn(8, 512, dtype=torch.float32) out = model(x).clone() From 0cc4977c7d33acccabe5ec3908b3da53d5db375b Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 01:29:28 -0700 Subject: [PATCH 08/18] fixing experimental test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- ...packed_linear_int8_dynamic_activation_intx_weight_layout.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py index 6caa0784d8..9046df799b 100644 --- a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -359,6 +359,9 @@ def _impl_2d_aten(input_tensor, weight_tensor): m, k = input_tensor.shape n, k_ = weight_tensor.shape + if m==0: # handling for empty input + return input_tensor + assert k_ == k group_size = weight_tensor.tensor_impl.get_layout().group_size packed_weight = weight_tensor.tensor_impl.packed_weight From 16bc60df26d2f795bf5ed11aee5a42629dbcad38 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 01:43:39 -0700 Subject: [PATCH 09/18] fixing experimental CI Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- ...cked_linear_int8_dynamic_activation_intx_weight_layout.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py index 9046df799b..dc7b073f32 100644 --- a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -359,8 +359,6 @@ def _impl_2d_aten(input_tensor, weight_tensor): m, k = 
input_tensor.shape n, k_ = weight_tensor.shape - if m==0: # handling for empty input - return input_tensor assert k_ == k group_size = weight_tensor.tensor_impl.get_layout().group_size @@ -369,6 +367,9 @@ def _impl_2d_aten(input_tensor, weight_tensor): input_tensor, packed_weight, group_size, k, n ) + if input_tensor.numel() == 0: + return input_tensor + target = weight_tensor.tensor_impl.get_layout().target if weight_tensor.tensor_impl.get_layout().has_bias: From 89ec74b26dd93f5b4a55d4c6846f7748b8af8565 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 06:32:05 -0700 Subject: [PATCH 10/18] fixing generate.py device stuff Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/_models/mixtral-moe/generate.py | 14 +++++++------- torchao/dtypes/floatx/float8_layout.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py index 393b82a388..9377ac1a3a 100644 --- a/torchao/_models/mixtral-moe/generate.py +++ b/torchao/_models/mixtral-moe/generate.py @@ -271,10 +271,10 @@ def main( config = Int4WeightOnlyConfig() elif "int4wo" in moe_quant: - config = MoEQuantConfig(Float8WeightOnlyConfig()) + config = MoEQuantConfig(Int4WeightOnlyConfig()) elif "fp8wo-base" in moe_quant: - config = Int4WeightOnlyConfig() + config = Float8WeightOnlyConfig() elif "fp8wo" in moe_quant: config = MoEQuantConfig(Float8WeightOnlyConfig()) @@ -297,7 +297,7 @@ def main( ) if config is not None: - quantize_(model, config, filter_fn=cond_ffn_filter) + quantize_(model, config, filter_fn=cond_ffn_filter, device=device) print( f"Time to apply quantization to model: {time.time() - t0:.02f} seconds" ) @@ -392,10 +392,10 @@ def callback(x): tokens_generated = y.size(-1) - prompt_length tokens_sec = tokens_generated / t aggregate_metrics["tokens_per_sec"].append(tokens_sec) - print( - f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec" - ) - print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + # print( + # f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec" + # ) + # print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") if i == 0 and device == "cuda" and memory_profile is not None: snapshot = torch.cuda.memory._snapshot() diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 29c4bf41b0..643e3f690c 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -159,7 +159,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): raise ValueError( f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" ) - elif func in [aten.select.int, func is aten.index.Tensor]: + elif func in [aten.select.int, aten.index.Tensor]: return return_and_correct_aliasing( func, args, From 8f6fdda028978a7e120bc40ae7975563b3bcba7b Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 06:41:16 -0700 Subject: [PATCH 11/18] fixing tests that aren't skipping Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_moe_quant.py | 77 ++++++++++++++++++++--------- 1 file changed, 55 insertions(+), 22 deletions(-) diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py index a18b3519b1..e177224c26 100644 --- a/test/quantization/test_moe_quant.py +++ b/test/quantization/test_moe_quant.py @@ -92,8 +92,6 @@ def _test_impl_moe_quant( self.assertGreaterEqual(compute_error(out_q, 
out), 10) self.assertGreaterEqual(compute_error(out_qc, out), 10) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") @parameterized.expand( [ ("single_token", 1, False), @@ -101,6 +99,11 @@ def _test_impl_moe_quant( ] ) def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest("Test only enabled for 2.5+") + config = MoEQuantConfig(Int4WeightOnlyConfig()) tensor_impl_class = TensorCoreTiledAQTTensorImpl @@ -111,9 +114,6 @@ def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): fullgraph=fullgraph, ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") @parameterized.expand( [ ("single_token", 1, True), @@ -121,6 +121,13 @@ def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): ] ) def test_int4wo_base(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not is_sm_at_least_90(): + self.skipTest("Requires CUDA capability >= 9.0") + if not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest("Test only enabled for 2.5+") + config = Int4WeightOnlyConfig() tensor_impl_class = TensorCoreTiledAQTTensorImpl @@ -131,8 +138,6 @@ def test_int4wo_base(self, name, num_tokens, fullgraph): fullgraph=fullgraph, ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") @parameterized.expand( [ ("single_token", 1, False), @@ -140,6 +145,11 @@ def test_int4wo_base(self, name, num_tokens, fullgraph): ] ) def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest("Test only enabled for 2.5+") + config = MoEQuantConfig(Int8WeightOnlyConfig()) tensor_impl_class = PlainAQTTensorImpl @@ -150,8 +160,6 @@ def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): fullgraph=fullgraph, ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") @parameterized.expand( [ ("single_token", 1, True), @@ -159,6 +167,11 @@ def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): ] ) def test_int8wo_base(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest("Test only enabled for 2.5+") + config = Int8WeightOnlyConfig() tensor_impl_class = PlainAQTTensorImpl @@ -169,7 +182,6 @@ def test_int8wo_base(self, name, num_tokens, fullgraph): fullgraph=fullgraph, ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") @parameterized.expand( [ ("single_token", 1, True), @@ -177,6 +189,9 @@ def test_int8wo_base(self, name, num_tokens, fullgraph): ] ) def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): + if not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest("Test only enabled for 2.5+") + config = Int8WeightOnlyConfig() tensor_impl_class = PlainAQTTensorImpl @@ -188,14 +203,17 @@ def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): device="cpu", ) - @unittest.skipIf(not 
torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") @parameterized.expand( [ ("multiple_tokens", 32, False), ] ) def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest("Test only enabled for 2.5+") + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) base_class = LinearActivationQuantizedTensor @@ -207,14 +225,17 @@ def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): fullgraph=fullgraph, ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+") @parameterized.expand( [ ("multiple_tokens", 32, False), ] ) def test_int8dq_base(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest("Test only enabled for 2.5+") + config = Int8DynamicActivationInt8WeightConfig() base_class = LinearActivationQuantizedTensor @@ -226,8 +247,6 @@ def test_int8dq_base(self, name, num_tokens, fullgraph): fullgraph=fullgraph, ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") @parameterized.expand( [ ("single_token", 1, False), @@ -235,6 +254,11 @@ def test_int8dq_base(self, name, num_tokens, fullgraph): ] ) def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not is_sm_at_least_90(): + self.skipTest("Requires CUDA capability >= 9.0") + config = MoEQuantConfig(Float8WeightOnlyConfig()) tensor_impl_class = Float8AQTTensorImpl @@ -245,8 +269,6 @@ def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph): fullgraph=fullgraph, ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") @parameterized.expand( [ ("single_token", 1, True), @@ -254,6 +276,11 @@ def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph): ] ) def test_fp8wo_base(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not is_sm_at_least_90(): + self.skipTest("Requires CUDA capability >= 9.0") + config = Float8WeightOnlyConfig() tensor_impl_class = Float8AQTTensorImpl @@ -264,8 +291,6 @@ def test_fp8wo_base(self, name, num_tokens, fullgraph): fullgraph=fullgraph, ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") @parameterized.expand( [ ("single_token", 1, False), @@ -273,6 +298,11 @@ def test_fp8wo_base(self, name, num_tokens, fullgraph): ] ) def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not is_sm_at_least_90(): + self.skipTest("Requires CUDA capability >= 9.0") + config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig()) base_class = LinearActivationQuantizedTensor @@ -283,8 +313,6 @@ def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph): fullgraph=fullgraph, ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") @parameterized.expand( [ ("single_token", 1, True), 
@@ -292,6 +320,11 @@ def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph): ] ) def test_fp8dq_base(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not is_sm_at_least_90(): + self.skipTest("Requires CUDA capability >= 9.0") + config = Float8DynamicActivationFloat8WeightConfig() base_class = LinearActivationQuantizedTensor From 330f69efe875b3cf1525b58a5f902924c89de66d Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 06:41:39 -0700 Subject: [PATCH 12/18] ruff format Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_moe_quant.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py index e177224c26..fbb2940b77 100644 --- a/test/quantization/test_moe_quant.py +++ b/test/quantization/test_moe_quant.py @@ -103,7 +103,7 @@ def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): self.skipTest("Need CUDA available") if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") - + config = MoEQuantConfig(Int4WeightOnlyConfig()) tensor_impl_class = TensorCoreTiledAQTTensorImpl @@ -127,7 +127,7 @@ def test_int4wo_base(self, name, num_tokens, fullgraph): self.skipTest("Requires CUDA capability >= 9.0") if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") - + config = Int4WeightOnlyConfig() tensor_impl_class = TensorCoreTiledAQTTensorImpl @@ -149,7 +149,7 @@ def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): self.skipTest("Need CUDA available") if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") - + config = MoEQuantConfig(Int8WeightOnlyConfig()) tensor_impl_class = PlainAQTTensorImpl @@ -171,7 +171,7 @@ def test_int8wo_base(self, name, num_tokens, fullgraph): self.skipTest("Need CUDA available") if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") - + config = Int8WeightOnlyConfig() tensor_impl_class = PlainAQTTensorImpl @@ -191,7 +191,7 @@ def test_int8wo_base(self, name, num_tokens, fullgraph): def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") - + config = Int8WeightOnlyConfig() tensor_impl_class = PlainAQTTensorImpl @@ -213,7 +213,7 @@ def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): self.skipTest("Need CUDA available") if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") - + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) base_class = LinearActivationQuantizedTensor @@ -235,7 +235,7 @@ def test_int8dq_base(self, name, num_tokens, fullgraph): self.skipTest("Need CUDA available") if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") - + config = Int8DynamicActivationInt8WeightConfig() base_class = LinearActivationQuantizedTensor @@ -258,7 +258,7 @@ def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph): self.skipTest("Need CUDA available") if not is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") - + config = MoEQuantConfig(Float8WeightOnlyConfig()) tensor_impl_class = Float8AQTTensorImpl @@ -280,7 +280,7 @@ def test_fp8wo_base(self, name, num_tokens, fullgraph): self.skipTest("Need CUDA available") if not is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") - + config = Float8WeightOnlyConfig() tensor_impl_class = Float8AQTTensorImpl @@ -302,7 +302,7 @@ 
def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph): self.skipTest("Need CUDA available") if not is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") - + config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig()) base_class = LinearActivationQuantizedTensor @@ -324,7 +324,7 @@ def test_fp8dq_base(self, name, num_tokens, fullgraph): self.skipTest("Need CUDA available") if not is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") - + config = Float8DynamicActivationFloat8WeightConfig() base_class = LinearActivationQuantizedTensor From e18e520ba215a5deb3a99826b2d9fd8acd120065 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 06:43:34 -0700 Subject: [PATCH 13/18] removing test code Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/_models/mixtral-moe/generate.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py index 9377ac1a3a..2f0fe74e77 100644 --- a/torchao/_models/mixtral-moe/generate.py +++ b/torchao/_models/mixtral-moe/generate.py @@ -386,16 +386,16 @@ def callback(x): if not interactive: pass - # print(tokenizer.decode(y[0].tolist())) + print(tokenizer.decode(y[0].tolist())) else: print() tokens_generated = y.size(-1) - prompt_length tokens_sec = tokens_generated / t aggregate_metrics["tokens_per_sec"].append(tokens_sec) - # print( - # f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec" - # ) - # print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + print( + f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec" + ) + print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") if i == 0 and device == "cuda" and memory_profile is not None: snapshot = torch.cuda.memory._snapshot() From 6e6f6eb7ceb6248b91621057f7dbfc671ea9a8c6 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 08:18:14 -0700 Subject: [PATCH 14/18] fixing CI Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 3 ++- test/quantization/test_moe_quant.py | 10 +++++----- torchao/_models/mixtral-moe/generate.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index c7428e72bd..5efb3ab6f1 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -2084,6 +2084,7 @@ def test_get_model_size_autoquant(self, device, dtype): ) mod(example_input) size2 = torchao.utils.get_model_size_in_bytes(mod) + print(size2, size) self.assertTrue(size2 < size) @parameterized.expand( @@ -2108,7 +2109,7 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype): size = torchao.utils.get_model_size_in_bytes(model) api(model) size2 = torchao.utils.get_model_size_in_bytes(model) - self.assertTrue(size2 < size) + self.assertGreaterEqual(size, size2) class TestBenchmarkModel(unittest.TestCase): diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py index fbb2940b77..c3107120d6 100644 --- a/test/quantization/test_moe_quant.py +++ b/test/quantization/test_moe_quant.py @@ -25,7 +25,7 @@ quantize_, ) from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_90 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_90, TORCH_VERSION_AT_LEAST_2_6 class 
TestMoEQuantCompile(unittest.TestCase): @@ -169,8 +169,8 @@ def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): def test_int8wo_base(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") + if not TORCH_VERSION_AT_LEAST_2_6: + self.skipTest("Test only enabled for 2.6+") config = Int8WeightOnlyConfig() tensor_impl_class = PlainAQTTensorImpl @@ -189,8 +189,8 @@ def test_int8wo_base(self, name, num_tokens, fullgraph): ] ) def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") + if not TORCH_VERSION_AT_LEAST_2_6: + self.skipTest("Test only enabled for 2.6+") config = Int8WeightOnlyConfig() tensor_impl_class = PlainAQTTensorImpl diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py index 2f0fe74e77..0a43f331e3 100644 --- a/torchao/_models/mixtral-moe/generate.py +++ b/torchao/_models/mixtral-moe/generate.py @@ -299,7 +299,7 @@ def main( if config is not None: quantize_(model, config, filter_fn=cond_ffn_filter, device=device) print( - f"Time to apply quantization to model: {time.time() - t0:.02f} seconds" + f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds" ) model.to(device=device) From c64118c1761484c98b1e43a0249e6aadbd926ac8 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 09:10:58 -0700 Subject: [PATCH 15/18] update API and remove branching on quant_api.py transform functions Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_moe_quant.py | 47 +++++-- torchao/_models/mixtral-moe/generate.py | 29 ++-- torchao/dtypes/affine_quantized_tensor_ops.py | 1 - torchao/dtypes/floatx/float8_layout.py | 35 ++++- ...est_int8_dynamic_activation_intx_weight.py | 3 +- .../prototype/moe_quant/README.md | 18 ++- .../prototype/moe_quant/llama4_quant.py | 4 +- .../quantization/prototype/moe_quant/utils.py | 64 +++++++-- torchao/quantization/quant_api.py | 130 ++++++------------ 9 files changed, 192 insertions(+), 139 deletions(-) diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py index c3107120d6..842468a769 100644 --- a/test/quantization/test_moe_quant.py +++ b/test/quantization/test_moe_quant.py @@ -12,6 +12,7 @@ from torchao.quantization.prototype.moe_quant.utils import ( FakeExtraDimTensor, MoEQuantConfig, + UseFakeExtraDimTensor, cond_ffn_filter, ) from torchao.quantization.quant_api import ( @@ -25,7 +26,11 @@ quantize_, ) from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_90, TORCH_VERSION_AT_LEAST_2_6 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, + is_sm_at_least_90, +) class TestMoEQuantCompile(unittest.TestCase): @@ -61,7 +66,10 @@ def _test_impl_moe_quant( quantize_(model, config, cond_ffn_filter) - if isinstance(config, MoEQuantConfig): + if ( + isinstance(config, MoEQuantConfig) + and config.use_fake_extra_dim_tensor == UseFakeExtraDimTensor.TRUE + ): self.assertIsInstance(model.experts.w1, FakeExtraDimTensor) if base_class is not None: self.assertIsInstance(model.experts.w1.head_tensor, base_class) @@ -104,7 +112,9 @@ def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") - config = 
MoEQuantConfig(Int4WeightOnlyConfig()) + config = MoEQuantConfig( + Int4WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE + ) tensor_impl_class = TensorCoreTiledAQTTensorImpl self._test_impl_moe_quant( @@ -128,7 +138,7 @@ def test_int4wo_base(self, name, num_tokens, fullgraph): if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") - config = Int4WeightOnlyConfig() + config = MoEQuantConfig(Int4WeightOnlyConfig()) tensor_impl_class = TensorCoreTiledAQTTensorImpl self._test_impl_moe_quant( @@ -150,7 +160,9 @@ def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") - config = MoEQuantConfig(Int8WeightOnlyConfig()) + config = MoEQuantConfig( + Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE + ) tensor_impl_class = PlainAQTTensorImpl self._test_impl_moe_quant( @@ -172,7 +184,7 @@ def test_int8wo_base(self, name, num_tokens, fullgraph): if not TORCH_VERSION_AT_LEAST_2_6: self.skipTest("Test only enabled for 2.6+") - config = Int8WeightOnlyConfig() + config = MoEQuantConfig(Int8WeightOnlyConfig()) tensor_impl_class = PlainAQTTensorImpl self._test_impl_moe_quant( @@ -192,7 +204,7 @@ def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): if not TORCH_VERSION_AT_LEAST_2_6: self.skipTest("Test only enabled for 2.6+") - config = Int8WeightOnlyConfig() + config = MoEQuantConfig(Int8WeightOnlyConfig()) tensor_impl_class = PlainAQTTensorImpl self._test_impl_moe_quant( @@ -214,7 +226,10 @@ def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") - config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) + config = MoEQuantConfig( + Int8DynamicActivationInt8WeightConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) base_class = LinearActivationQuantizedTensor self._test_impl_moe_quant( @@ -236,7 +251,7 @@ def test_int8dq_base(self, name, num_tokens, fullgraph): if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") - config = Int8DynamicActivationInt8WeightConfig() + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) base_class = LinearActivationQuantizedTensor self._test_impl_moe_quant( @@ -259,7 +274,10 @@ def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph): if not is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") - config = MoEQuantConfig(Float8WeightOnlyConfig()) + config = MoEQuantConfig( + Float8WeightOnlyConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) tensor_impl_class = Float8AQTTensorImpl self._test_impl_moe_quant( @@ -281,7 +299,7 @@ def test_fp8wo_base(self, name, num_tokens, fullgraph): if not is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") - config = Float8WeightOnlyConfig() + config = MoEQuantConfig(Float8WeightOnlyConfig()) tensor_impl_class = Float8AQTTensorImpl self._test_impl_moe_quant( @@ -303,7 +321,10 @@ def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph): if not is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") - config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig()) + config = MoEQuantConfig( + Float8DynamicActivationFloat8WeightConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) base_class = LinearActivationQuantizedTensor self._test_impl_moe_quant( @@ -325,7 +346,7 @@ def test_fp8dq_base(self, name, num_tokens, fullgraph): if not 
is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") - config = Float8DynamicActivationFloat8WeightConfig() + config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig()) base_class = LinearActivationQuantizedTensor self._test_impl_moe_quant( diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py index 0a43f331e3..6a4236d629 100644 --- a/torchao/_models/mixtral-moe/generate.py +++ b/torchao/_models/mixtral-moe/generate.py @@ -239,6 +239,7 @@ def main( from torchao.quantization.prototype.moe_quant.utils import ( MoEQuantConfig, cond_ffn_filter, + UseFakeExtraDimTensor ) from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, @@ -256,40 +257,44 @@ def main( torch._dynamo.config.capture_dynamic_output_shape_ops = True config = None if "int8wo-base" in moe_quant: - config = Int8WeightOnlyConfig() + config = MoEQuantConfig(Int8WeightOnlyConfig()) elif "int8wo" in moe_quant: - config = MoEQuantConfig(Int8WeightOnlyConfig()) + config = MoEQuantConfig(Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE) elif "int8dq-base" in moe_quant: - config = Int8DynamicActivationInt8WeightConfig() + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) elif "int8dq" in moe_quant: - config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE) elif "int4wo-base" in moe_quant: - config = Int4WeightOnlyConfig() + config = MoEQuantConfig(Int4WeightOnlyConfig()) elif "int4wo" in moe_quant: - config = MoEQuantConfig(Int4WeightOnlyConfig()) + config = MoEQuantConfig(Int4WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE) elif "fp8wo-base" in moe_quant: - config = Float8WeightOnlyConfig() + config = MoEQuantConfig(Float8WeightOnlyConfig()) elif "fp8wo" in moe_quant: - config = MoEQuantConfig(Float8WeightOnlyConfig()) + config = MoEQuantConfig(Float8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE) elif "fp8dq-base" in moe_quant: - config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) elif "fp8dq" in moe_quant: config = MoEQuantConfig( - Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, ) elif "intxdq" in moe_quant: - config = Int8DynamicActivationIntxWeightConfig( - layout=PackedLinearInt8DynamicActivationIntxWeightLayout() + config = MoEQuantConfig( + Int8DynamicActivationIntxWeightConfig( + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), + ), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, ) else: assert config is not None, ( diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 0283e67760..1d70f5c7f3 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -504,7 +504,6 @@ def _(func, types, args, kwargs): assert len(indices) == 1, ( f"op {func} currently only implemented for single dimensional indexing but got indices: {indices}" ) - new_tensor_impl = aten.index.Tensor(self.tensor_impl, indices) shape = tuple([indices[0].numel(), *self.shape[1:]]) diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py 
index 643e3f690c..5a23eca66a 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -55,6 +55,7 @@ class Float8Layout(Layout): mm_config: Optional[Float8MMConfig] = None +_fallback_warning_shown = False @register_layout(Float8Layout) class Float8AQTTensorImpl(AQTTensorImpl): @@ -100,12 +101,34 @@ def __init__( def _apply_fn_to_data(self, fn): """Applys a fn to all tensor components stored on this class""" - return self.__class__( - fn(self.float8_data), - fn(self.scale), - self.transposed, - self._layout, - ) + global _fallback_warning_shown + + try: + return self.__class__( + fn(self.float8_data), + fn(self.scale), + self.transposed, + self._layout, + ) + except RuntimeError as e: + if '"index_cuda" not implemented for ' in str(e): + if not _fallback_warning_shown: + import warnings + warnings.warn( + f"When trying to index Float8AQTTensorImpl, got known error {e}, will use slower fallback but " + + "note: You can torch.compile the model to avoid this problem.", + UserWarning + ) + _fallback_warning_shown = True + + return self.__class__( # do indexing in bfloat16 then convert back + fn(self.float8_data.to(torch.bfloat16)).to(self.float8_data.dtype), + fn(self.scale), + self.transposed, + self._layout, + ) + else: + raise e def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index a5c6437e8e..b1667c0dec 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -637,6 +637,7 @@ def test_moe_quant_intx(self): FakeExtraDimTensor, MoEQuantConfig, cond_ffn_filter, + UseFakeExtraDimTensor, ) from torchao.quantization.quant_api import ( Int8DynamicActivationIntxWeightConfig, @@ -656,7 +657,7 @@ def test_moe_quant_intx(self): base_config = Int8DynamicActivationIntxWeightConfig( layout=PackedLinearInt8DynamicActivationIntxWeightLayout() ) - moe_config = MoEQuantConfig(base_config) + moe_config = MoEQuantConfig(base_config, use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE) quantize_(model, moe_config, cond_ffn_filter) diff --git a/torchao/quantization/prototype/moe_quant/README.md b/torchao/quantization/prototype/moe_quant/README.md index c80fb679a1..d774fae8fd 100644 --- a/torchao/quantization/prototype/moe_quant/README.md +++ b/torchao/quantization/prototype/moe_quant/README.md @@ -10,10 +10,10 @@ The API for moe quantization is very similar to linear quantization, given a moe ```python -from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter +from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter, MoEQuantConfig from torchao.quantization.quant_api import quantize_, Int8WeightOnlyConfig -quantize_(model, Int8WeightOnlyConfig(), filter_fn=cond_ffn_filter) +quantize_(model, MoEQuantConfig(Int8WeightOnlyConfig()), filter_fn=cond_ffn_filter) model=torch.compile(model, mode="reduce-overhead") # you can also use fullgraph=True for single token inference ``` @@ -23,20 +23,26 @@ This api is the same as for normal linear quantization but with a specific filte ## Alternative Quantization API -To make the above api work, each tensor subclass had to be edited to work as 3D tensors. However the only ops we actually need to support are a few indexing and slicing ops on the 0th dimension, the majority of the work was removing hard coded assumptions about the tensor dimensionality. This means its possible to instead create a new tensor subclass that pretends to be a 3D tensor by storing a series of 2D tensors and simulating the slicing and indexing ops until eventually just returning the singular desired 2D quantized tensor subclass. This can be achieved using the alternative api as follows: +To make the above api work, each tensor subclass had to be edited to work as a 3D tensor. However, the only ops we actually need to support are a few indexing and slicing ops on the 0th dimension; the majority of the work was removing hard-coded assumptions about tensor dimensionality. This means it's possible to instead create a new tensor subclass that pretends to be a 3D tensor by storing a series of 2D tensors and simulating the slicing and indexing ops until it eventually returns the single desired 2D quantized tensor subclass. This can be achieved with the alternative api by changing the use_fake_extra_dim_tensor flag of the MoEQuantConfig: ```python -from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter, MoEQuantConfig +from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter, MoEQuantConfig, UseFakeExtraDimTensor from torchao.quantization.quant_api import quantize_, Int8DynamicActivationIntxWeightConfig -config = MoEQuantConfig(Int8DynamicActivationIntxWeightConfig()) +config = MoEQuantConfig( + Int8DynamicActivationIntxWeightConfig(), + # this is the only difference from the above api + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, +) quantize_(model, config, filter_fn=cond_ffn_filter) model=torch.compile(model, mode="reduce-overhead") ``` -While this approach turns out to not be especially performant, it does allow for comparable memory characteristics, allowing models that wouldn't fit on a single node/gpu to actually run. It is flexible enough however to work with all of the existing linear quantization techniques that make use of quantized tensor subclasses without any changes being made to those classes. It is compilable though even single token inference doesn't work with fullgraph compilation. +Note that the default value of use_fake_extra_dim_tensor is AS_FALLBACK, which tries the base method first and falls back to the more general but less performant FakeExtraDimTensor method if the base method fails. +While this approach turns out not to be especially performant, it does have slightly better memory characteristics since the tensors are held separately and aren't actually modified or indexed. It is flexible enough to work with all of the existing linear quantization techniques that use quantized tensor subclasses without any changes to those classes. It is compilable, though neither single-token nor multi-token inference works with fullgraph compilation. 
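For intuition, the core trick can be sketched in a few lines. The `StackedWeight` class below is a hypothetical toy, not the real FakeExtraDimTensor (which wraps quantized tensor subclasses and supports more ops), but it shows how a list of 2D tensors can masquerade as a 3D weight when only 0th-dimension indexing is needed:

```python
import torch


class StackedWeight:
    """Toy stand-in for FakeExtraDimTensor: holds a list of 2D tensors,
    advertises a 3D shape, and only supports indexing along dim 0."""

    def __init__(self, tensors):
        assert tensors and all(t.dim() == 2 for t in tensors)
        self.tensors = list(tensors)

    @property
    def shape(self):
        return (len(self.tensors), *self.tensors[0].shape)

    def __getitem__(self, idx):
        # selecting an expert just returns the stored 2D tensor, so any
        # existing 2D (quantized) linear kernel can be reused unchanged
        return self.tensors[idx]


w1 = StackedWeight([torch.randn(256, 512) for _ in range(8)])  # 8 experts
x = torch.randn(4, 512)  # 4 tokens routed to expert 3
y = torch.nn.functional.linear(x, w1[3])
print(w1.shape, y.shape)  # (8, 256, 512) torch.Size([4, 256])
```

The real FakeExtraDimTensor follows the same pattern but stores already-quantized 2D subclasses, which is why it composes with existing linear quantization configs without touching their kernels.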
## Model API diff --git a/torchao/quantization/prototype/moe_quant/llama4_quant.py b/torchao/quantization/prototype/moe_quant/llama4_quant.py index d68c0d28c1..48f99ad631 100644 --- a/torchao/quantization/prototype/moe_quant/llama4_quant.py +++ b/torchao/quantization/prototype/moe_quant/llama4_quant.py @@ -70,9 +70,9 @@ def convert_fn(module): model = model from torchao.quantization import Int4WeightOnlyConfig, quantize_ -from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter +from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter, MoEQuantConfig -quantize_(model, Int4WeightOnlyConfig(), cond_ffn_filter, device="cuda") +quantize_(model, MoEQuantConfig(Int4WeightOnlyConfig()), cond_ffn_filter, device="cuda") model.cuda() diff --git a/torchao/quantization/prototype/moe_quant/utils.py b/torchao/quantization/prototype/moe_quant/utils.py index 22e786bc3c..4acb11fa4e 100644 --- a/torchao/quantization/prototype/moe_quant/utils.py +++ b/torchao/quantization/prototype/moe_quant/utils.py @@ -13,6 +13,8 @@ register_quantize_module_handler, ) from torchao.utils import fill_defaults +from enum import Enum, auto +from torchao.quantization.quant_api import _QUANTIZE_CONFIG_HANDLER class DummyModule(torch.nn.Module): @@ -211,6 +213,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise e +class UseFakeExtraDimTensor(Enum): + """Enum that indicate whether to use FakeExtraDimTensor + """ + TRUE = auto() + FALSE = auto() + AS_FALLBACK = auto() + @dataclass class MoEQuantConfig(AOBaseConfig): @@ -220,6 +229,48 @@ class MoEQuantConfig(AOBaseConfig): """ base_config: AOBaseConfig + use_fake_extra_dim_tensor: UseFakeExtraDimTensor = UseFakeExtraDimTensor.AS_FALLBACK + set_inductor_config: bool=True + + +# Module-level flag to track if we've already printed the error +_moe_quant_tensor_has_printed_error = False + +def _moe_quant_tensor(weight, config): + def _moe_quant_tensor_base(weight, config): + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + dummy_mod = DummyModule(weight) + quant_mod = base_config_handler(dummy_mod, config.base_config) + return quant_mod.weight + + def _moe_quant_tensor_fake_extra_dim_tensor(weight, config): + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + # break 3D tensor + tensors = [weight[i] for i in range(weight.shape[0])] + # put tensors into modules since the handlers target modules not tensors + dummy_modules = [DummyModule(tensor) for tensor in tensors] + # apply handler to each module + quant_mods = list(map(lambda x: base_config_handler(x, config.base_config), dummy_modules)) + # pack quantized subclasses into FakeExtraDimTensor + quant_weight = FakeExtraDimTensor([mod.weight for mod in quant_mods]) + return quant_weight + + global _moe_quant_tensor_has_printed_error + + use_fake = config.use_fake_extra_dim_tensor + if use_fake == UseFakeExtraDimTensor.FALSE: + return _moe_quant_tensor_base(weight, config) + elif use_fake == UseFakeExtraDimTensor.AS_FALLBACK: + try: + return _moe_quant_tensor_base(weight, config) + except Exception as e: + if not _moe_quant_tensor_has_printed_error: + print(f"tried to do moe_quant but got error: {e}") + _moe_quant_tensor_has_printed_error = True + return _moe_quant_tensor_fake_extra_dim_tensor(weight, config) + else: # This handles UseFakeExtraDimTensor.TRUE + return _moe_quant_tensor_fake_extra_dim_tensor(weight, config) + @register_quantize_module_handler(MoEQuantConfig) @@ -229,24 +280,15 @@ def moe_quant_fn(module, config: 
MoEQuantConfig): warnings.simplefilter("ignore", lineno=84) warnings.simplefilter("ignore", lineno=105) assert "ConditionalFeedForwardAOQuantizable" in str(type(module)) - from torchao.quantization.quant_api import _QUANTIZE_CONFIG_HANDLER for weight_attr in ["w1", "w2", "w3"]: param = getattr(module, weight_attr) + assert param.dim() == 3, f"when applying moe_quant to {module} expected 3D tensor for {weight_attr} but got {param.dim()}" assert isinstance(config.base_config, AOBaseConfig), ( f"MoEQuantConfig expected to be initialized with an AOBaseConfig but got {type(config.base_config)}" + "this can happen if you initiaze with MoEQuantConfig(AOConfig) rather than MoEQuantConfig(AOConfig())" ) - handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] - - # break 3D tensor - tensors = [param[i] for i in range(param.shape[0])] - # put tensors into modules since the handlers target modules not tensors - dummy_modules = [DummyModule(tensor) for tensor in tensors] - # apply handler to each module - out_mods = list(map(lambda x: handler(x, config.base_config), dummy_modules)) - # pack quantized subclasses into FakeExtraDimTensor - new_param = FakeExtraDimTensor([mod.weight for mod in out_mods]) + new_param = _moe_quant_tensor(param, config) new_param = torch.nn.Parameter(new_param, requires_grad=False) setattr(module, weight_attr, new_param) del param diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 471af07858..c20c37a194 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1130,22 +1130,14 @@ def _int4_weight_only_transform( if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - if "ConditionalFeedForwardAOQuantizable" in str(type(module)): - for weight_attr in ["w1", "w2", "w3"]: - weight = getattr(module, weight_attr) - new_weight = _int4_weight_only_quantize_tensor(weight, config) - new_weight = torch.nn.Parameter(new_weight, requires_grad=False) - setattr(module, weight_attr, new_weight) - return module - else: - assert hasattr(module, "weight"), ( - "applying int8 weight only quant requires module to have weight attribute" - + " but {module} does not have one" - ) - new_weight = _int4_weight_only_quantize_tensor(module.weight, config) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module + assert hasattr(module, "weight"), ( + "applying int8 weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + new_weight = _int4_weight_only_quantize_tensor(module.weight, config) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module @dataclass @@ -1187,22 +1179,14 @@ def _int8_weight_only_transform(module: torch.nn.Module, config: Int8WeightOnlyC if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - if "ConditionalFeedForwardAOQuantizable" in str(type(module)): - for weight_attr in ["w1", "w2", "w3"]: - weight = getattr(module, weight_attr) - new_weight = _int8_weight_only_quantize_tensor(weight, config) - new_weight = torch.nn.Parameter(new_weight, requires_grad=False) - setattr(module, weight_attr, new_weight) - return module - else: - assert hasattr(module, "weight"), ( - "applying int8 weight only quant requires module to have weight attribute" - + " but {module} does not have one" - ) - 
new_weight = _int8_weight_only_quantize_tensor(module.weight, config) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module + assert hasattr(module, "weight"), ( + "applying int8 weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + new_weight = _int8_weight_only_quantize_tensor(module.weight, config) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor: @@ -1372,25 +1356,15 @@ def _int8_dynamic_activation_int8_weight_transform( if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - if "ConditionalFeedForwardAOQuantizable" in str(type(module)): - for weight_attr in ["w1", "w2", "w3"]: - weight = getattr(module, weight_attr) - new_weight = _int8_dynamic_activation_int8_weight_quantize_tensor( - weight, config - ) - new_weight = torch.nn.Parameter(new_weight, requires_grad=False) - setattr(module, weight_attr, new_weight) - return module - else: - assert hasattr(module, "weight"), ( - "applying int8 dynamic activation int8 weight quant requires module to have weight attribute" - + "but {module} does not have one" - ) - new_weight = _int8_dynamic_activation_int8_weight_quantize_tensor( - module.weight, config - ) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) + assert hasattr(module, "weight"), ( + "applying int8 dynamic activation int8 weight quant requires module to have weight attribute" + + "but {module} does not have one" + ) + new_weight = _int8_dynamic_activation_int8_weight_quantize_tensor( + module.weight, config + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -1449,23 +1423,15 @@ def _float8_weight_only_transform( if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - if "ConditionalFeedForwardAOQuantizable" in str(type(module)): - for weight_attr in ["w1", "w2", "w3"]: - weight = getattr(module, weight_attr) - new_weight = _float8_weight_only_quant_tensor(weight, config) - new_weight = torch.nn.Parameter(new_weight, requires_grad=False) - setattr(module, weight_attr, new_weight) - return module - else: - assert hasattr(module, "weight"), ( - "applying int8 weight only quant requires module to have weight attribute" - + " but {module} does not have one" - ) - new_weight = _float8_weight_only_quant_tensor(module.weight, config) + assert hasattr(module, "weight"), ( + "applying int8 weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + new_weight = _float8_weight_only_quant_tensor(module.weight, config) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module _fp8_granularities = Union[PerTensor, PerRow] @@ -1668,26 +1634,16 @@ def _float8_dynamic_activation_float8_weight_transform( if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - if 
"ConditionalFeedForwardAOQuantizable" in str(type(module)): - for weight_attr in ["w1", "w2", "w3"]: - weight = getattr(module, weight_attr) - quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( - weight, config - ) - new_weight = torch.nn.Parameter(quantized_weight, requires_grad=False) - setattr(module, weight_attr, new_weight) - return module - else: - assert hasattr(module, "weight"), ( - "applying float8 dynamic activation quant requires module to have weight attribute" - + f"but {module} does not have one" - ) - quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( - module.weight, config - ) - module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module + assert hasattr(module, "weight"), ( + "applying float8 dynamic activation quant requires module to have weight attribute" + + f"but {module} does not have one" + ) + quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( + module.weight, config + ) + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module @dataclass From 6684b17853fe5ef691f19e5390e15099e5c09268 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 11:01:24 -0700 Subject: [PATCH 16/18] ruff format Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/_models/mixtral-moe/generate.py | 26 ++++++++++++++----- torchao/dtypes/floatx/float8_layout.py | 11 +++++--- ...est_int8_dynamic_activation_intx_weight.py | 6 +++-- .../prototype/moe_quant/llama4_quant.py | 5 +++- .../quantization/prototype/moe_quant/utils.py | 23 +++++++++------- 5 files changed, 49 insertions(+), 22 deletions(-) diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py index 6a4236d629..0dcd86e74f 100644 --- a/torchao/_models/mixtral-moe/generate.py +++ b/torchao/_models/mixtral-moe/generate.py @@ -238,8 +238,8 @@ def main( from torchao.quantization.prototype.moe_quant.utils import ( MoEQuantConfig, + UseFakeExtraDimTensor, cond_ffn_filter, - UseFakeExtraDimTensor ) from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, @@ -260,28 +260,42 @@ def main( config = MoEQuantConfig(Int8WeightOnlyConfig()) elif "int8wo" in moe_quant: - config = MoEQuantConfig(Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE) + config = MoEQuantConfig( + Int8WeightOnlyConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) elif "int8dq-base" in moe_quant: config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) elif "int8dq" in moe_quant: - config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE) + config = MoEQuantConfig( + Int8DynamicActivationInt8WeightConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) elif "int4wo-base" in moe_quant: config = MoEQuantConfig(Int4WeightOnlyConfig()) elif "int4wo" in moe_quant: - config = MoEQuantConfig(Int4WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE) + config = MoEQuantConfig( + Int4WeightOnlyConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) elif "fp8wo-base" in moe_quant: config = MoEQuantConfig(Float8WeightOnlyConfig()) elif "fp8wo" in moe_quant: - config = MoEQuantConfig(Float8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE) + 
config = MoEQuantConfig( + Float8WeightOnlyConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) elif "fp8dq-base" in moe_quant: - config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) + config = MoEQuantConfig( + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + ) elif "fp8dq" in moe_quant: config = MoEQuantConfig( diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 5a23eca66a..5914f00102 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -55,8 +55,10 @@ class Float8Layout(Layout): mm_config: Optional[Float8MMConfig] = None + _fallback_warning_shown = False + @register_layout(Float8Layout) class Float8AQTTensorImpl(AQTTensorImpl): """ @@ -102,7 +104,7 @@ def __init__( def _apply_fn_to_data(self, fn): """Applys a fn to all tensor components stored on this class""" global _fallback_warning_shown - + try: return self.__class__( fn(self.float8_data), @@ -114,14 +116,15 @@ def _apply_fn_to_data(self, fn): if '"index_cuda" not implemented for ' in str(e): if not _fallback_warning_shown: import warnings + warnings.warn( f"When trying to index Float8AQTTensorImpl, got known error {e}, will use slower fallback but " + "note: You can torch.compile the model to avoid this problem.", - UserWarning + UserWarning, ) _fallback_warning_shown = True - - return self.__class__( # do indexing in bfloat16 then convert back + + return self.__class__( # do indexing in bfloat16 then convert back fn(self.float8_data.to(torch.bfloat16)).to(self.float8_data.dtype), fn(self.scale), self.transposed, diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index b1667c0dec..d1236e9183 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -636,8 +636,8 @@ def test_moe_quant_intx(self): from torchao.quantization.prototype.moe_quant.utils import ( FakeExtraDimTensor, MoEQuantConfig, - cond_ffn_filter, UseFakeExtraDimTensor, + cond_ffn_filter, ) from torchao.quantization.quant_api import ( Int8DynamicActivationIntxWeightConfig, @@ -657,7 +657,9 @@ def test_moe_quant_intx(self): base_config = Int8DynamicActivationIntxWeightConfig( layout=PackedLinearInt8DynamicActivationIntxWeightLayout() ) - moe_config = MoEQuantConfig(base_config, use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE) + moe_config = MoEQuantConfig( + base_config, use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE + ) quantize_(model, moe_config, cond_ffn_filter) diff --git a/torchao/quantization/prototype/moe_quant/llama4_quant.py b/torchao/quantization/prototype/moe_quant/llama4_quant.py index 48f99ad631..67ad2ab464 100644 --- a/torchao/quantization/prototype/moe_quant/llama4_quant.py +++ b/torchao/quantization/prototype/moe_quant/llama4_quant.py @@ -70,7 +70,10 @@ def convert_fn(module): model = model from torchao.quantization import Int4WeightOnlyConfig, quantize_ -from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter, MoEQuantConfig +from torchao.quantization.prototype.moe_quant.utils import ( + MoEQuantConfig, + cond_ffn_filter, +) quantize_(model, MoEQuantConfig(Int4WeightOnlyConfig()), cond_ffn_filter, device="cuda") diff --git a/torchao/quantization/prototype/moe_quant/utils.py b/torchao/quantization/prototype/moe_quant/utils.py index 4acb11fa4e..16fa8c8d33 
100644 --- a/torchao/quantization/prototype/moe_quant/utils.py +++ b/torchao/quantization/prototype/moe_quant/utils.py @@ -5,16 +5,16 @@ aten = torch.ops.aten +from enum import Enum, auto from typing import List, Optional, Tuple, Union from torchao.quantization.quant_api import ( + _QUANTIZE_CONFIG_HANDLER, AOBaseConfig, dataclass, register_quantize_module_handler, ) from torchao.utils import fill_defaults -from enum import Enum, auto -from torchao.quantization.quant_api import _QUANTIZE_CONFIG_HANDLER class DummyModule(torch.nn.Module): @@ -213,9 +213,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise e + class UseFakeExtraDimTensor(Enum): - """Enum that indicate whether to use FakeExtraDimTensor - """ + """Enum that indicate whether to use FakeExtraDimTensor""" + TRUE = auto() FALSE = auto() AS_FALLBACK = auto() @@ -230,12 +231,13 @@ class MoEQuantConfig(AOBaseConfig): base_config: AOBaseConfig use_fake_extra_dim_tensor: UseFakeExtraDimTensor = UseFakeExtraDimTensor.AS_FALLBACK - set_inductor_config: bool=True + set_inductor_config: bool = True # Module-level flag to track if we've already printed the error _moe_quant_tensor_has_printed_error = False + def _moe_quant_tensor(weight, config): def _moe_quant_tensor_base(weight, config): base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] @@ -250,13 +252,15 @@ def _moe_quant_tensor_fake_extra_dim_tensor(weight, config): # put tensors into modules since the handlers target modules not tensors dummy_modules = [DummyModule(tensor) for tensor in tensors] # apply handler to each module - quant_mods = list(map(lambda x: base_config_handler(x, config.base_config), dummy_modules)) + quant_mods = list( + map(lambda x: base_config_handler(x, config.base_config), dummy_modules) + ) # pack quantized subclasses into FakeExtraDimTensor quant_weight = FakeExtraDimTensor([mod.weight for mod in quant_mods]) return quant_weight global _moe_quant_tensor_has_printed_error - + use_fake = config.use_fake_extra_dim_tensor if use_fake == UseFakeExtraDimTensor.FALSE: return _moe_quant_tensor_base(weight, config) @@ -272,7 +276,6 @@ def _moe_quant_tensor_fake_extra_dim_tensor(weight, config): return _moe_quant_tensor_fake_extra_dim_tensor(weight, config) - @register_quantize_module_handler(MoEQuantConfig) def moe_quant_fn(module, config: MoEQuantConfig): import warnings @@ -283,7 +286,9 @@ def moe_quant_fn(module, config: MoEQuantConfig): for weight_attr in ["w1", "w2", "w3"]: param = getattr(module, weight_attr) - assert param.dim() == 3, f"when applying moe_quant to {module} expected 3D tensor for {weight_attr} but got {param.dim()}" + assert param.dim() == 3, ( + f"when applying moe_quant to {module} expected 3D tensor for {weight_attr} but got {param.dim()}" + ) assert isinstance(config.base_config, AOBaseConfig), ( f"MoEQuantConfig expected to be initialized with an AOBaseConfig but got {type(config.base_config)}" + "this can happen if you initiaze with MoEQuantConfig(AOConfig) rather than MoEQuantConfig(AOConfig())" From d34fdc588db02ee127ff6c4fdbbafa8b5a22a182 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 11:06:42 -0700 Subject: [PATCH 17/18] fix weird ci error Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 5efb3ab6f1..bf22eff478 100644 --- a/test/integration/test_integration.py +++ 
b/test/integration/test_integration.py @@ -2084,8 +2084,7 @@ def test_get_model_size_autoquant(self, device, dtype): ) mod(example_input) size2 = torchao.utils.get_model_size_in_bytes(mod) - print(size2, size) - self.assertTrue(size2 < size) + self.assertGreaterEqual(size, size2) @parameterized.expand( list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)), From d927f06911bf337a2decaff1a6ca9304c5e17003 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 8 May 2025 11:55:06 -0700 Subject: [PATCH 18/18] remove change to test_integration.py Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index bf22eff478..c7428e72bd 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -2084,7 +2084,7 @@ def test_get_model_size_autoquant(self, device, dtype): ) mod(example_input) size2 = torchao.utils.get_model_size_in_bytes(mod) - self.assertGreaterEqual(size, size2) + self.assertTrue(size2 < size) @parameterized.expand( list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)), @@ -2108,7 +2108,7 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype): size = torchao.utils.get_model_size_in_bytes(model) api(model) size2 = torchao.utils.get_model_size_in_bytes(model) - self.assertGreaterEqual(size, size2) + self.assertTrue(size2 < size) class TestBenchmarkModel(unittest.TestCase):
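Taken together, the series leaves the user-facing flow looking roughly like the sketch below. It is assembled from the tests and generate.py changes above rather than being an exact excerpt; the toy dimensions (hidden 512, expert 256, 8 experts, top-2 routing), the bfloat16 dtype, the Int8WeightOnlyConfig choice, and the CUDA device are illustrative assumptions.

```python
import torch

from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import (
    MOEFeedForwardAOQuantizable,
)
from torchao.quantization.prototype.moe_quant.utils import (
    MoEQuantConfig,
    UseFakeExtraDimTensor,
    cond_ffn_filter,
)
from torchao.quantization.quant_api import Int8WeightOnlyConfig, quantize_

# toy MoE block: hidden_dim=512, expert_dim=256, 8 experts, top-2 routing
model = (
    MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False)
    .to(torch.bfloat16)
    .to("cuda")
)

# default UseFakeExtraDimTensor.AS_FALLBACK: quantize the 3D expert weights
# directly and only fall back to FakeExtraDimTensor if that fails
config = MoEQuantConfig(Int8WeightOnlyConfig())
# force the fallback path instead:
# config = MoEQuantConfig(
#     Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE
# )

quantize_(model, config, filter_fn=cond_ffn_filter)

# allow data-dependent shapes in the expert routing path when compiling
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
model = torch.compile(model, mode="reduce-overhead")

x = torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")
out = model(x)
```

In the updated generate.py, the `-base` variants of the `--moe_quant` options correspond to the default config above, while the non-`-base` variants set `use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE`.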