v1.3.5 Quantization

@DefTruth released this on 30 Mar 08:13 (commit a5022a5)

Low-bits Quantization

Overview

Quantization is a powerful technique to reduce the memory footprint and computational cost of deep learning models by representing weights and activations with lower-precision data types. Cache-DiT supports several quantization methods, including FP8, INT8, and INT4 quantization, to help users achieve faster inference and lower memory usage while maintaining acceptable model quality.
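As a rough illustration of why lower bit widths matter, weight memory scales with the number of bits stored per parameter. The sketch below is a back-of-the-envelope estimate, not a Cache-DiT utility: the 9B parameter count is a hypothetical example, and the estimate ignores activations, quantization scales, and any layers left unquantized.

```python
# Back-of-the-envelope weight memory: parameters x bits per parameter / 8 bytes.
# Ignores activations, per-row/per-block scale tensors, and unquantized layers.
BITS_PER_PARAM = {
    "bfloat16": 16,
    "float8": 8,
    "int8": 8,
    "int4": 4,
}


def weight_memory_gb(num_params: int, dtype: str) -> float:
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BITS_PER_PARAM[dtype] / 8 / 1e9


if __name__ == "__main__":
    n = 9_000_000_000  # hypothetical 9B-parameter model
    for dtype in ("bfloat16", "float8", "int8", "int4"):
        print(f"{dtype:>8}: {weight_memory_gb(n, dtype):.1f} GB")
```

Halving the bit width halves the weight storage, which is why FP8/INT8 roughly halve and INT4 roughly quarters the memory of a bfloat16 checkpoint.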

| Quantization type | Description | Devices |
|---|---|---|
| float8_per_row | Quantize weights and activations to float8 (dynamic quantization) with row-wise scaling. (recommended) | >= sm89 (Ada, Hopper or newer) |
| float8_per_tensor | Quantize weights and activations to float8 (dynamic quantization) with tensor-wise scaling. | >= sm89 (Ada, Hopper or newer) |
| float8_per_block | Quantize weights and activations to float8 block-wise (dynamic quantization), which can provide better precision. Block size: (1, 128) for activations, (128, 128) for weights. | >= sm89 (Ada, Hopper or newer) |
| float8_weight_only | Quantize only weights to float8; keep activations in full precision. | >= sm89 (Ada, Hopper or newer) |
| int8_per_row | Quantize weights and activations to int8 (dynamic quantization) with row-wise scaling. | >= sm80 (Ampere or newer) |
| int8_per_tensor | Quantize weights and activations to int8 (dynamic quantization) with tensor-wise scaling. | >= sm80 (Ampere or newer) |
| int8_weight_only | Quantize only weights to int8; keep activations in full precision. | >= sm80 (Ampere or newer) |
| int4_weight_only | Quantize only weights to int4; keep activations in full precision. | >= sm90 (Hopper or newer, TMA required) |
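The device column can be turned into a simple lookup. The helper below is hypothetical (it is not part of the Cache-DiT API); it only transcribes the minimum compute capabilities from the table, so meeting the minimum does not by itself guarantee that every kernel ships for a given newer architecture.

```python
# Minimum CUDA compute capability per quant type, transcribed from the table.
# Hypothetical helper for illustration; not a Cache-DiT function.
MIN_CAPABILITY = {
    "float8_per_row": (8, 9),
    "float8_per_tensor": (8, 9),
    "float8_per_block": (8, 9),
    "float8_weight_only": (8, 9),
    "int8_per_row": (8, 0),
    "int8_per_tensor": (8, 0),
    "int8_weight_only": (8, 0),
    "int4_weight_only": (9, 0),  # Hopper or newer, TMA required
}


def supported_quant_types(capability: tuple) -> list:
    """Return quant types whose minimum (major, minor) capability is met."""
    # Tuples compare lexicographically, so (8, 9) >= (8, 0) and (9, 0) > (8, 9).
    return [q for q, min_cap in MIN_CAPABILITY.items() if capability >= min_cap]


if __name__ == "__main__":
    for cap in [(8, 0), (8, 6), (8, 9), (9, 0)]:
        print(cap, supported_quant_types(cap))
```

On a real machine, the capability tuple can be obtained with torch.cuda.get_device_capability().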

FP8 Quantization

TorchAO is fully integrated into Cache-DiT as the backend for online quantization. You can quantize a model by calling cache_dit.quantize, or (recommended) by passing a QuantizeConfig to the enable_cache API.

For GPUs with limited memory, we recommend float8_per_row or float8_per_block, as these methods cause almost no precision loss. Supported FP8 quantization types include:

  • float8_per_row: quantize both weights and activations to float8 (dynamic quantization) with row-wise scaling.
  • float8_per_tensor: quantize both weights and activations to float8 (dynamic quantization) with tensor-wise scaling.
  • float8_per_block: quantize weights and activations to float8 block-wise (dynamic quantization), which can provide better precision. Block size: (1, 128) for activations, (128, 128) for weights. NOT supported for distributed inference yet.
  • float8_weight_only: quantize only weights to float8 and keep activations in full precision.
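The scaling granularities above differ in how many scale factors are computed and what each one covers. The sketch below is a simplified, pure-Python illustration (not TorchAO's implementation): each scale is max|x| over its group divided by 448, the largest finite float8 e4m3 value, so every group uses the full FP8 range. It also counts how many scales each granularity needs for a hypothetical matrix shape.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3


def per_tensor_scale(x):
    """One scale for the whole matrix: max |x| / FP8 max."""
    return max(abs(v) for row in x for v in row) / FP8_E4M3_MAX


def per_row_scales(x):
    """One scale per row, so an outlier in one row does not affect the others."""
    return [max(abs(v) for v in row) / FP8_E4M3_MAX for row in x]


def num_block_scales(rows, cols, block=(128, 128)):
    """Number of scales for block-wise quantization with the given block size."""
    return math.ceil(rows / block[0]) * math.ceil(cols / block[1])


if __name__ == "__main__":
    x = [[0.5, -1.0], [100.0, 2.0]]  # row 1 contains an outlier
    print(per_tensor_scale(x))       # single scale driven by the outlier
    print(per_row_scales(x))         # row 0 keeps its own, much smaller scale
    # Hypothetical 3072 x 3072 weight and 4096 x 3072 activation:
    print(num_block_scales(3072, 3072))              # weights, (128, 128) blocks
    print(num_block_scales(4096, 3072, (1, 128)))    # activations, (1, 128) blocks
```

More scales mean finer-grained ranges (and usually better precision) at the cost of extra scale storage and bookkeeping.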

The following examples show how to use quantization with Cache-DiT. The simplest way is to pass the quantization config directly to the enable_cache API.

import cache_dit
from cache_dit import DBCacheConfig, ParallelismConfig, QuantizeConfig

# quant_type: float8_per_row, float8_per_tensor, float8_per_block, float8_weight_only, 
# int8_per_row, int8_per_tensor, int8_weight_only, int4_weight_only, etc.
# Pass a QuantizeConfig to the `enable_cache` API.
cache_dit.enable_cache( 
    pipe, cache_config=DBCacheConfig(), # w/ default
    parallelism_config=ParallelismConfig(ulysses_size=2),
    quantize_config=QuantizeConfig(quant_type="float8_per_row"),
)

Users can also specify different quantization configs for different components. For example, quantize the transformer to float8_per_row and the text encoder to float8_weight_only.

import cache_dit
from cache_dit import DBCacheConfig, ParallelismConfig, QuantizeConfig

cache_dit.enable_cache( 
    pipe, cache_config=DBCacheConfig(), # w/ default
    parallelism_config=ParallelismConfig(ulysses_size=2),
    quantize_config=QuantizeConfig(
        components_to_quantize={
            "transformer": {
                "quant_type": "float8_per_row",
                "exclude_layers": ["embedder", "embed"],
            },
            "text_encoder": {
                "quant_type": "float8_weight_only",
                "exclude_layers": ["lm_head"],
            }
        }
    ),
)

Alternatively, call the quantize API directly for finer-grained control.

import cache_dit
from cache_dit import QuantizeConfig

cache_dit.quantize(
    pipe.transformer, 
    quantize_config=QuantizeConfig(quant_type="float8_per_row"),
)
cache_dit.quantize(
    pipe.text_encoder, 
    quantize_config=QuantizeConfig(quant_type="float8_weight_only"),
)

Please also enable torch.compile for better performance with quantization.

import torch
import cache_dit

cache_dit.set_compile_configs()
pipe.transformer = torch.compile(pipe.transformer)
pipe.text_encoder = torch.compile(pipe.text_encoder)

Use exclude_layers in QuantizeConfig to skip sensitive layers that are not robust to quantization, such as embedding layers. Any layer whose name contains one of the keywords in exclude_layers is excluded from quantization. For example:

import cache_dit
from cache_dit import DBCacheConfig, ParallelismConfig, QuantizeConfig

cache_dit.enable_cache( 
    pipe, cache_config=DBCacheConfig(), # w/ default
    parallelism_config=ParallelismConfig(ulysses_size=2),
    quantize_config=QuantizeConfig(
        quant_type="float8_per_row",
        exclude_layers=["embedder", "embed"],
    ),
)
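The keyword rule described above can be sketched as a substring check on a layer's qualified name. This is a minimal illustration assuming plain substring matching; the function is hypothetical, not Cache-DiT's own code, and the layer names are made up for the example.

```python
def is_excluded(layer_name: str, exclude_layers: list) -> bool:
    """True if the layer's qualified name contains any exclude keyword."""
    return any(keyword in layer_name for keyword in exclude_layers)


# Hypothetical qualified layer names from a transformer.
names = [
    "x_embedder",
    "context_embedder",
    "transformer_blocks.0.attn.to_q",
    "transformer_blocks.0.ff.linear_out",
]
to_quantize = [n for n in names if not is_excluded(n, ["embedder", "embed"])]
print(to_quantize)
```

Because matching is by substring, a short keyword such as "embed" already covers "embedder"; keep keywords specific enough that they do not accidentally exclude unrelated layers.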

By default, quant_type is "float8_per_row", which gives better precision. Set it to "float8_per_tensor" to use per-tensor quantization, which can perform better on some hardware.

Regional Quantization

Cache-DiT also supports regional quantization, which quantizes only the repeated blocks of a transformer. This helps balance precision and efficiency. Specify the blocks to quantize via the regional_quantize and repeated_blocks arguments of QuantizeConfig. For example, to quantize the repeated blocks of the FLUX.2 transformer:

import cache_dit
from cache_dit import DBCacheConfig, ParallelismConfig, QuantizeConfig

cache_dit.enable_cache( 
    pipe, cache_config=DBCacheConfig(), # w/ default
    parallelism_config=ParallelismConfig(ulysses_size=2),
    quantize_config=QuantizeConfig(
        quant_type="float8_per_row",
        # Default: True. Quantize only the repeated blocks of the transformer
        # when repeated_blocks is available; if False, quantize the whole transformer.
        regional_quantize=True, 
        # Class names of the repeated transformer blocks. Cache-DiT finds all
        # instances of these blocks and quantizes them in place. The names come
        # from the model architecture; for FLUX.2 they are "Flux2TransformerBlock"
        # and "Flux2SingleTransformerBlock".
        repeated_blocks=['Flux2TransformerBlock', 'Flux2SingleTransformerBlock'],
        # If omitted, repeated_blocks defaults to transformer._repeated_blocks when
        # the diffusers transformer class defines it; otherwise None, and the whole
        # transformer is quantized.
    ),
)
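Conceptually, regional quantization selects only the Linear layers nested inside modules whose class name appears in repeated_blocks. The sketch below illustrates that selection on a hypothetical module tree given as (qualified name, class name) pairs; with PyTorch these pairs would come from model.named_modules() and type(module).__name__. It is an assumption-based illustration, not Cache-DiT's implementation.

```python
# Hypothetical module tree: (qualified name, class name) pairs.
MODULES = [
    ("x_embedder", "Linear"),
    ("transformer_blocks.0", "Flux2TransformerBlock"),
    ("transformer_blocks.0.attn.to_q", "Linear"),
    ("transformer_blocks.0.ff.linear_out", "Linear"),
    ("single_transformer_blocks.0", "Flux2SingleTransformerBlock"),
    ("single_transformer_blocks.0.attn.to_out", "Linear"),
    ("proj_out", "Linear"),
]


def regional_linear_layers(modules, repeated_blocks):
    """Return names of Linear layers nested inside any repeated block."""
    # Every instance of a repeated block contributes a name prefix.
    prefixes = [name + "." for name, cls in modules if cls in repeated_blocks]
    return [
        name
        for name, cls in modules
        if cls == "Linear" and any(name.startswith(p) for p in prefixes)
    ]


selected = regional_linear_layers(
    MODULES, ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
)
print(selected)
```

Layers outside the repeated blocks (here x_embedder and proj_out) stay in their original precision.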

FP8 Per-Tensor Fallback

The per_tensor_fallback option in Cache-DiT's quantization configuration enables a fallback mechanism for layers that do not support float8 per-row or per-block quantization. This is particularly useful with tensor parallelism, where certain layers (e.g., layers parallelized with RowwiseParallel) may hit memory layout mismatch errors when quantized to float8 per-row.

When per_tensor_fallback is set to True, a layer that cannot be quantized to float8 per-row or per-block automatically falls back to float8 per-tensor quantization instead of raising an error. Quantization therefore proceeds without interruption, and those layers still benefit from reduced precision.

To enable this feature, set the per_tensor_fallback flag to True (the default) in the QuantizeConfig passed to the enable_cache API. Currently, only float8 quantization is supported. For example:

import cache_dit
from cache_dit import DBCacheConfig, ParallelismConfig, QuantizeConfig

cache_dit.enable_cache( 
    pipe, cache_config=DBCacheConfig(), # w/ default
    parallelism_config=ParallelismConfig(tp_size=2),
    quantize_config=QuantizeConfig(
        quant_type="float8_per_row",
        # regional_quantize must be True for the fp8 per-tensor fallback to apply.
        regional_quantize=True, # default, True.
        repeated_blocks=['Flux2TransformerBlock', 'Flux2SingleTransformerBlock'],
        # Enable fallback to float8 per-tensor quantization (default: True) for
        # better compatibility with layers that do not support float8 per-row
        # quantization, e.g., layers with RowwiseParallel applied in tensor parallelism.
        per_tensor_fallback=True, 
    ),
)
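The fallback decision can be modeled as a per-layer choice between the requested type, float8_per_tensor, and skipping. The sketch below is a simplified model under stated assumptions: layer names are hypothetical, only float8_per_row and float8_per_block are treated as affected by the RowwiseParallel layout issue, and the 88/56 split simply mirrors the layer counts reported in the logs in this section.

```python
# Quant types assumed to hit the RowwiseParallel layout issue described above.
FALLBACK_ELIGIBLE = {"float8_per_row", "float8_per_block"}


def plan_layers(layers, quant_type="float8_per_row", per_tensor_fallback=True):
    """Map layer name -> quant type, or None when the layer is skipped.

    `layers` maps each (hypothetical) layer name to True if it is sharded
    with RowwiseParallel under tensor parallelism.
    """
    plan = {}
    for name, rowwise_sharded in layers.items():
        if not rowwise_sharded or quant_type not in FALLBACK_ELIGIBLE:
            plan[name] = quant_type
        elif per_tensor_fallback:
            plan[name] = "float8_per_tensor"
        else:
            plan[name] = None  # skipped; the text says a warning is raised
    return plan


def count_quantized(plan):
    return sum(q is not None for q in plan.values())


# 88 unaffected layers and 56 RowwiseParallel layers, as in the logs.
layers = {f"other_{i}": False for i in range(88)}
layers.update({f"rowwise_{i}": True for i in range(56)})

print(count_quantized(plan_layers(layers)))
print(count_quantized(plan_layers(layers, per_tensor_fallback=False)))
```

With the fallback, all 144 layers get a quantized type; without it, only the 88 unaffected layers do.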

For example, without the fp8 per-tensor fallback, Cache-DiT automatically skips layers that do not support float8 per-row quantization and raises a warning for each of them. Performance is worse because fewer layers are quantized (88 layers quantized, 56 skipped):

# w/o fp8 per-tensor fallback, quantize 88 layers, skip 56 layers, performance downgrade.
torchrun --nproc_per_node=2 -m cache_dit.generate flux2_klein_9b_kv_edit \
   --parallel tp --compile --float8-per-row --q-verbose \
   --disable-per-tensor-fallback
-----------------------------------------------------------------------------------
Quantized        Region: ['Flux2TransformerBlock', 'Flux2SingleTransformerBlock']  |
Quantized Linear Layers: 88    float8_per_row     56 (skipped)                     |
Quantized Linear Layers: 88    (total)                                             |
Skipped   Linear Layers: 56    (total)                                             |
Linear           Layers: 144   (total)                                             |
-----------------------------------------------------------------------------------
------------------------------------------------------------------------------------
float8_per_row, skip: attn.to_out.0        : pattern<RowwiseParallel>: 8    layers  |
float8_per_row, skip: attn.to_add_out      : pattern<RowwiseParallel>: 8    layers  |
float8_per_row, skip: ff.linear_out        : pattern<RowwiseParallel>: 8    layers  |
float8_per_row, skip: ff_context.linear_out: pattern<RowwiseParallel>: 8    layers  |
float8_per_row, skip: attn.to_out          : pattern<RowwiseParallel>: 24   layers  |
------------------------------------------------------------------------------------

With the fp8 per-tensor fallback enabled, layers that do not support float8 per-row quantization are quantized to float8 per-tensor instead, and performance improves because more layers are quantized (144 layers quantized, 0 skipped):

# w/ fp8 per-tensor fallback enabled, quantize 144 layers, skip 0 layer, better performance.
torchrun --nproc_per_node=2 -m cache_dit.generate flux2_klein_9b_kv_edit \
   --parallel tp --compile --float8-per-row --q-verbose  
# fp8 per-tensor fallback is enabled by default.
-----------------------------------------------------------------------------------
Quantized        Region: ['Flux2TransformerBlock', 'Flux2SingleTransformerBlock']  |
Quantized Linear Layers: 88    float8_per_row     0 (skipped)                      |
Quantized Linear Layers: 56    float8_per_tensor  0 (skipped)                      |
Quantized Linear Layers: 144   (total)                                             |
Skipped   Linear Layers: 0     (total)                                             |
Linear           Layers: 144   (total)                                             |
-----------------------------------------------------------------------------------

(Hybrid) Precision Plan

The precision_plan option in QuantizeConfig lets users assign different quantization types to layers whose names match given patterns. This gives finer control over the accuracy/performance trade-off for attention sub-layers (for example, keeping to_k/to_v in float8_per_row while using float8_per_tensor for to_q/to_out). Please note:

  • Layers not matched by precision_plan continue to use the base quant_type.
  • precision_plan is only valid when regional_quantize=True. If regional quantization is disabled, precision plan will be ignored.
  • precision_plan is compatible with per_tensor_fallback. If a planned type is not supported for a specific layer or hardware path (e.g., row-wise tensor parallelism with float8_per_row as the base quantization type), the fallback logic still applies automatically when enabled.

For example, with FLUX.2-Klein-9b-kv:

import cache_dit
from cache_dit import DBCacheConfig, ParallelismConfig, QuantizeConfig

cache_dit.enable_cache(
    pipe,
    cache_config=DBCacheConfig(),
    quantize_config=QuantizeConfig(
        # Default type for transformer layers not matched by precision_plan.
        quant_type="float8_per_row",
        regional_quantize=True,
        repeated_blocks=['Flux2TransformerBlock', 'Flux2SingleTransformerBlock'],
        per_tensor_fallback=True,
        precision_plan={
            "attn.to_q": "float8_per_tensor",  # match: **attn.to_q**, best performance. 
            "attn.to_k": "float8_weight_only", # match: **attn.to_k**, best precision.
            "attn.to_v": "float8_per_block",   # match: **attn.to_v**, better precision.
            "attn.to_out": "float8_per_row",   # match: **attn.to_out**, better precision.
        },
    ),
)
# python3 -m cache_dit.generate flux2_klein_9b_kv_edit --config quantize_plan.yaml --compile

The output summary then shows how many layers use each quantization type, so users can verify that the precision plan was applied correctly.

-----------------------------------------------------------------------------------
Quantized        Region: ['Flux2TransformerBlock', 'Flux2SingleTransformerBlock']  |
Quantized Linear Layers: 96    float8_per_row     0 (skipped)                      |
Quantized Linear Layers: 32    float8_per_tensor  0 (skipped)                      |
Quantized Linear Layers: 8     float8_per_block   0 (skipped)                      |
Quantized Linear Layers: 8     float8_weight_only 0 (skipped)                      |
Quantized Linear Layers: 144   (total)                                             |
Skipped   Linear Layers: 0     (total)                                             |
Linear           Layers: 144   (total)                                             |
-----------------------------------------------------------------------------------
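The per-type counts in such a summary can be reproduced in miniature. The sketch below assumes that a layer takes the type of the first precision_plan pattern contained in its name and otherwise falls back to the base quant_type; Cache-DiT's actual matching order and rules may differ, and the layer names are hypothetical.

```python
from collections import Counter


def resolve_quant_type(layer_name, base_type, precision_plan):
    """Return the type of the first plan pattern found in layer_name, else base_type."""
    for pattern, quant_type in precision_plan.items():
        if pattern in layer_name:
            return quant_type
    return base_type


plan = {
    "attn.to_q": "float8_per_tensor",
    "attn.to_k": "float8_weight_only",
    "attn.to_v": "float8_per_block",
    "attn.to_out": "float8_per_row",
}
# Hypothetical layer names from one transformer block.
names = [
    "transformer_blocks.0.attn.to_q",
    "transformer_blocks.0.attn.to_k",
    "transformer_blocks.0.attn.to_v",
    "transformer_blocks.0.attn.to_out.0",
    "transformer_blocks.0.ff.linear_in",
]
summary = Counter(resolve_quant_type(n, "float8_per_row", plan) for n in names)
print(summary)
```

Under substring matching, a pattern like "attn.to_q" would also match a name such as "attn.to_qkv", so it is worth checking the printed summary after changing a plan.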

INT8/INT4 Quantization

In addition to FP8 quantization, Cache-DiT supports INT8 quantization (weights and activations, or weights only) and INT4 weight-only quantization; INT4 in particular can reduce the model's memory footprint further. Set quant_type in QuantizeConfig to int8_per_row, int8_per_tensor, int8_weight_only, or int4_weight_only when calling the enable_cache API. For example:

import cache_dit
from cache_dit import DBCacheConfig, ParallelismConfig, QuantizeConfig  

cache_dit.enable_cache( 
    pipe,
    # Or "int8_per_tensor", "int8_weight_only", "int4_weight_only", etc.
    quantize_config=QuantizeConfig(quant_type="int8_per_row"), 
)

INT4 quantization reduces memory even more than FP8 or INT8, but may cause more precision loss. We recommend trying different quantization types and choosing the one that best fits your trade-off between performance and precision. In most cases, float8_per_row is a good choice, reducing memory while maintaining acceptable precision.
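Why fewer bits cost precision can be seen with a simple symmetric round-trip. The sketch below quantizes one row to a given bit width with a single per-row scale and dequantizes it again; it is a generic textbook illustration with made-up values, not the mslk or TorchAO kernels, which may use different scaling schemes.

```python
def fake_quantize(row, bits):
    """Symmetric per-row quantize-then-dequantize to the given bit width."""
    qmax = 2 ** (bits - 1) - 1  # 127 for int8, 7 for int4
    amax = max(abs(v) for v in row)
    scale = amax / qmax if amax > 0 else 1.0
    # Round to the nearest integer level and clamp to [-qmax, qmax].
    q = [max(-qmax, min(qmax, round(v / scale))) for v in row]
    return [v * scale for v in q]


def max_abs_error(row, bits):
    """Largest absolute difference between the row and its round-trip."""
    return max(abs(a - b) for a, b in zip(row, fake_quantize(row, bits)))


row = [0.013, -0.27, 0.5, 0.91, -1.0, 0.33]  # made-up weight values
print(max_abs_error(row, 8))
print(max_abs_error(row, 4))
```

With 8 bits the step size is amax/127, while with 4 bits it is amax/7, so the worst-case rounding error per element grows by roughly 18x.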

Note that the mslk kernel library must also be installed to enable INT8/INT4 quantization. The int4_weight_only (W4A16) compute kernel requires sm90 or newer (Hopper or newer, TMA required). On older architectures, use int8_weight_only for better compatibility.

# stable: mslk, torch and torchao (change cu130 to cu129 if using CUDA 12.9)
uv pip install torch==2.11.0 torchvision torchao triton mslk --index-url https://download.pytorch.org/whl/cu130 --upgrade
# nightly: mslk, torch and torchao (change cu130 to cu129 if using CUDA 12.9)
uv pip install --pre torch torchvision torchao triton mslk --index-url https://download.pytorch.org/whl/nightly/cu130 --upgrade

For distributed inference (context parallelism or tensor parallelism), we recommend float8 quantization to avoid potential compatibility issues.

Nunchaku (W4A4)

Cache-DiT natively supports the Hybrid Cache + Nunchaku + Context Parallelism scheme. Users can combine caching and context parallelism to speed up Nunchaku 4-bit (W4A4) models.

import torch
import cache_dit
from diffusers import QwenImagePipeline
from nunchaku import NunchakuQwenImageTransformer2DModel

# Load the Nunchaku 4-bit (W4A4) transformer.
transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
    "path-to/svdq-int4_r32-qwen-image.safetensors"
)
pipe = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16,
).to("cuda")

cache_dit.enable_cache(pipe, cache_config=..., parallelism_config=...)