thu-ml/TurboDiffusion

TurboDiffusion

This repository provides the official implementation of TurboDiffusion, a video generation acceleration framework that can speed up end-to-end diffusion generation by $100 \sim 205\times$ on a single RTX 5090, while maintaining video quality.

Paper: TurboDiffusion: Accelerating Video Diffusion Models by 100--205 Times

Original, E2E Time: 166s
TurboDiffusion, E2E Time: 1.8s
An example of a 5-second video generated by Wan-2.1-T2V-1.3B-480P on a single RTX 5090.

Available Models

Model Name                    Checkpoint Link      Best Resolution
TurboWan2.1-T2V-1.3B-480P     Huggingface Model    480p
TurboWan2.1-T2V-14B-480P      Huggingface Model    480p
TurboWan2.1-T2V-14B-720P      Huggingface Model    720p
TurboWan2.2-I2V-A14B-720P     Huggingface Model    720p

Note: All checkpoints support generating videos at 480p or 720p. The "Best Resolution" column indicates the resolution at which the model provides the best video quality.

Installation

Base environment: python>=3.9, torch>=2.7.0. torch==2.8.0 is recommended, as higher versions may cause OOM.
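A minimal sketch for checking an existing environment against these requirements before installing. parse_version and check_environment are hypothetical helpers, not part of TurboDiffusion. The installed torch version is read with importlib.metadata, so torch itself is never imported. Flagging any torch other than 2.8.0 as "supported but not recommended" is our reading of the note above.

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version


def parse_version(text):
    """'2.8.0+cu128' -> (2, 8, 0); local suffixes such as '+cu128' are ignored."""
    numbers = [int(n) for n in re.findall(r"\d+", text.split("+")[0])[:3]]
    return tuple(numbers + [0] * (3 - len(numbers)))


def check_environment(python_version, torch_version):
    """Compare a (major, minor) Python version and a torch version string with the stated requirements."""
    if tuple(python_version) < (3, 9):
        return "unsupported: python>=3.9 required"
    torch_v = parse_version(torch_version)
    if torch_v < (2, 7, 0):
        return "unsupported: torch>=2.7.0 required"
    if torch_v != (2, 8, 0):
        return "supported: torch==2.8.0 is recommended"
    return "ok"


if __name__ == "__main__":
    try:
        installed = version("torch")
    except PackageNotFoundError:
        installed = "0"  # torch not installed
    print(check_environment(sys.version_info[:2], installed))
```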

Install TurboDiffusion by pip:

conda create -n turbodiffusion python=3.12
conda activate turbodiffusion

pip install turbodiffusion --no-build-isolation

Or compile from source:

git clone https://github.com/thu-ml/TurboDiffusion.git
cd TurboDiffusion
git submodule update --init --recursive
pip install -e . --no-build-isolation

To enable SageSLA (Sparse-Linear Attention based on SageAttention), install SpargeAttn first:

pip install git+https://github.com/thu-ml/SpargeAttn.git --no-build-isolation
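SLA combines exact softmax attention on a small set of important blocks with cheaper linear attention for the rest, and the inference scripts expose a top-k ratio for it (--sla_topk). The Python sketch below illustrates what such a ratio plausibly controls: for each query block, keep the highest-scoring fraction of key blocks for exact attention. It is an illustration under assumptions, not the SpargeAttn kernel: select_critical_blocks is hypothetical, and both the per-query-block selection and the round-up rule are our assumptions.

```python
import math


def select_critical_blocks(block_scores, topk_ratio=0.1):
    """Return, for each query block, the sorted indices of key blocks kept for exact attention.

    block_scores: one list per query block with a relevance score per key block
    (for example, computed from mean-pooled Q and K). In an SLA-style scheme, the
    key blocks not selected here would go to the linear-attention branch.
    """
    selected = []
    for row in block_scores:
        # Assumption: round up and always keep at least one key block.
        k = max(1, math.ceil(topk_ratio * len(row)))
        # Indices of the k highest-scoring key blocks.
        top = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        selected.append(sorted(top))
    return selected


if __name__ == "__main__":
    scores = [[0.1, 0.9, 0.3, 0.2, 0.05, 0.4, 0.0, 0.7, 0.6, 0.8]]
    print(select_critical_blocks(scores, 0.1))   # [[1]]
    print(select_critical_blocks(scores, 0.15))  # [[1, 9]]
```

With 10 key blocks, raising the ratio from 0.1 to 0.15 keeps one extra block per query block for exact attention, which is consistent with the higher quality and higher cost the README associates with 0.15.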

Inference

  1. Download the Wan2.1 VAE and umT5 text encoder checkpoints from the official Wan2.1 repository on Huggingface:

    mkdir checkpoints
    cd checkpoints
    wget https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B/resolve/main/Wan2.1_VAE.pth
    wget https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B/resolve/main/models_t5_umt5-xxl-enc-bf16.pth
  2. Download our finetuned checkpoints:

    wget https://huggingface.co/TurboDiffusion/TurboWan2.1-T2V-1.3B-480P/resolve/main/TurboWan2.1-T2V-1.3B-480P.pth

    For RTX 5090, RTX 4090, or similar GPUs, please use the quantized checkpoint:

    wget https://huggingface.co/TurboDiffusion/TurboWan2.1-T2V-1.3B-480P/resolve/main/TurboWan2.1-T2V-1.3B-480P-quant.pth

    For GPUs with more than 40GB of memory, e.g., H100, we recommend using the unquantized checkpoint.

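    For the I2V model, download both the high-noise and low-noise checkpoints. The I2V command in step 3 loads the quantized -quant.pth variants of these files together with --quant_linear; if you use the unquantized files, adjust the paths and omit --quant_linear: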

    wget https://huggingface.co/TurboDiffusion/TurboWan2.2-I2V-A14B-720P/resolve/main/TurboWan2.2-I2V-A14B-high-720P.pth
    wget https://huggingface.co/TurboDiffusion/TurboWan2.2-I2V-A14B-720P/resolve/main/TurboWan2.2-I2V-A14B-low-720P.pth
  3. Use the inference script for the T2V model:

    export PYTHONPATH=turbodiffusion
    
    # Arguments:
    # --dit_path            Path to the finetuned TurboDiffusion checkpoint
    # --model               Model to use: Wan2.1-1.3B or Wan2.1-14B (default: Wan2.1-1.3B)
    # --num_samples         Number of videos to generate (default: 1)
    # --num_steps           Sampling steps, 1–4 (default: 4)
    # --sigma_max           Initial sigma for rCM (default: 80); larger choices (e.g., 1600) reduce diversity but may enhance quality
    # --vae_path            Path to Wan2.1 VAE (default: checkpoints/Wan2.1_VAE.pth)
    # --text_encoder_path   Path to umT5 text encoder (default: checkpoints/models_t5_umt5-xxl-enc-bf16.pth)
    # --num_frames          Number of frames to generate (default: 77)
    # --prompt              Text prompt for video generation
    # --resolution          Output resolution: "480p" or "720p" (default: 480p)
    # --aspect_ratio        Aspect ratio in W:H format (default: 16:9)
    # --seed                Random seed for reproducibility (default: 0)
    # --save_path           Output file path including extension (default: output/generated_video.mp4)
    # --attention_type      Attention module to use: original, sla or sagesla (default: sagesla)
    # --sla_topk            Top-k ratio for SLA/SageSLA attention (default: 0.1); we recommend 0.15 for better video quality
    # --quant_linear        Enable quantization for linear layers, pass this if using a quantized checkpoint
    # --default_norm        Use the original LayerNorm and RMSNorm of Wan models
    
    python turbodiffusion/inference/wan2.1_t2v_infer.py \
        --model Wan2.1-1.3B \
        --dit_path checkpoints/TurboWan2.1-T2V-1.3B-480P-quant.pth \
        --resolution 480p \
        --prompt "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about." \
        --num_samples 1 \
        --num_steps 4 \
        --quant_linear \
        --attention_type sagesla \
        --sla_topk 0.1

    Or use the script for the I2V model:

    export PYTHONPATH=turbodiffusion
    
    # --image_path              Path to the input image
    # --high_noise_model_path   Path to the high noise TurboDiffusion checkpoint
    # --low_noise_model_path    Path to the low noise TurboDiffusion checkpoint
    # --boundary                Timestep boundary for switching from high to low noise model (default: 0.9)
    # --model                   Model to use: Wan2.2-A14B (default: Wan2.2-A14B)
    # --num_samples             Number of videos to generate (default: 1)
    # --num_steps               Sampling steps, 1–4 (default: 4)
    # --sigma_max               Initial sigma for rCM (default: 200); larger choices (e.g., 1600) reduce diversity but may enhance quality
    # --vae_path                Path to Wan2.2 VAE (default: checkpoints/Wan2.2_VAE.pth)
    # --text_encoder_path       Path to umT5 text encoder (default: checkpoints/models_t5_umt5-xxl-enc-bf16.pth)
    # --num_frames              Number of frames to generate (default: 77)
    # --prompt                  Text prompt for video generation
    # --resolution              Output resolution: "480p" or "720p" (default: 720p)
    # --aspect_ratio            Aspect ratio in W:H format (default: 16:9)
    # --adaptive_resolution     Enable adaptive resolution based on input image size
    # --ode                     Use ODE for sampling (sharper but less robust than SDE)
    # --seed                    Random seed for reproducibility (default: 0)
    # --save_path               Output file path including extension (default: output/generated_video.mp4)
    # --attention_type          Attention module to use: original, sla or sagesla (default: sagesla)
    # --sla_topk                Top-k ratio for SLA/SageSLA attention (default: 0.1); we recommend 0.15 for better video quality
    # --quant_linear            Enable quantization for linear layers, pass this if using a quantized checkpoint
    # --default_norm            Use the original LayerNorm and RMSNorm of Wan models
    
    python turbodiffusion/inference/wan2.2_i2v_infer.py \
        --model Wan2.2-A14B \
        --low_noise_model_path checkpoints/TurboWan2.2-I2V-A14B-low-720P-quant.pth \
        --high_noise_model_path checkpoints/TurboWan2.2-I2V-A14B-high-720P-quant.pth \
        --resolution 720p \
        --adaptive_resolution \
        --image_path assets/i2v_inputs/i2v_input_0.jpg \
        --prompt "POV selfie video, ultra-messy and extremely fast. A white cat in sunglasses stands on a surfboard with a neutral look when the board suddenly whips sideways, throwing cat and camera into the water; the frame dives sharply downward, swallowed by violent bursts of bubbles, spinning turbulence, and smeared water streaks as the camera sinks. Shadows thicken, pressure ripples distort the edges, and loose bubbles rush upward past the lens, showing the camera is still sinking. Then the cat kicks upward with explosive speed, dragging the view through churning bubbles and rapidly brightening water as sunlight floods back in; the camera races upward, water streaming off the lens, and finally breaks the surface in a sudden blast of light and spray, snapping back into a crooked, frantic selfie as the cat resurfaces." \
        --num_samples 1 \
        --num_steps 4 \
        --quant_linear \
        --attention_type sagesla \
        --sla_topk 0.1 \
        --ode
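The checkpoint and flag pairing from the inference steps can be captured in a small Python helper. This is a hypothetical sketch: t2v_command and i2v_expert are not part of TurboDiffusion. It selects the -quant.pth checkpoint at or below 40GB of GPU memory and always pairs it with --quant_linear, using only flags documented in the argument list. i2v_expert illustrates the --boundary rule under the assumption that timesteps are normalized to [0, 1] with 1 meaning pure noise; the time parameterization actually used by wan2.2_i2v_infer.py may differ.

```python
import shlex

# Hypothetical helpers (not part of the TurboDiffusion package).
QUANTIZED_UP_TO_GB = 40  # README: GPUs with more than 40GB (e.g. H100) can use the unquantized checkpoint


def t2v_command(prompt, gpu_mem_gb, model="Wan2.1-1.3B",
                ckpt_stem="TurboWan2.1-T2V-1.3B-480P", resolution="480p", sla_topk=0.1):
    """Build argv for wan2.1_t2v_infer.py, pairing a -quant.pth checkpoint with --quant_linear."""
    quantized = gpu_mem_gb <= QUANTIZED_UP_TO_GB
    dit_path = f"checkpoints/{ckpt_stem}{'-quant' if quantized else ''}.pth"
    argv = [
        "python", "turbodiffusion/inference/wan2.1_t2v_infer.py",
        "--model", model,
        "--dit_path", dit_path,
        "--resolution", resolution,
        "--prompt", prompt,
        "--num_samples", "1",
        "--num_steps", "4",
        "--attention_type", "sagesla",
        "--sla_topk", str(sla_topk),
    ]
    if quantized:
        argv.append("--quant_linear")  # required whenever a quantized checkpoint is loaded
    return argv


def i2v_expert(t, boundary=0.9):
    """Pick the Wan2.2-A14B expert for a timestep t, assumed normalized to [0, 1] (1 = pure noise)."""
    return "high_noise" if t >= boundary else "low_noise"


if __name__ == "__main__":
    print(shlex.join(t2v_command("A white cat surfing at sunset", gpu_mem_gb=32)))
    # Illustrative timesteps only; the real 4-step schedule depends on --sigma_max and the sampler.
    print([i2v_expert(t) for t in (1.0, 0.95, 0.6, 0.3)])
    # ['high_noise', 'high_noise', 'low_noise', 'low_noise']
```

Deriving both the checkpoint suffix and --quant_linear from a single boolean keeps the two from drifting apart, which is the mismatch most likely to break a run.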

Evaluation

We evaluate video generation on a single RTX 5090 GPU. The E2E Time refers to the end-to-end diffusion generation latency, excluding text encoding and VAE decoding.

Wan-2.2-I2V-A14B-720P

Original, E2E Time: 4183s
TurboDiffusion, E2E Time: 35.4s

Wan-2.1-T2V-1.3B-480P

Original, E2E Time: 166s
FastVideo, E2E Time: 4.7s
TurboDiffusion, E2E Time: 1.8s

Wan-2.1-T2V-14B-720P

Original, E2E Time: 4648s
FastVideo, E2E Time: 63.5s
TurboDiffusion, E2E Time: 22.7s

Wan-2.1-T2V-14B-480P

Original, E2E Time: 1635s
FastVideo, E2E Time: 23.2s
TurboDiffusion, E2E Time: 9.4s
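The speedups implied by these latencies can be computed directly. This short Python snippet uses only the E2E times listed in this section and divides each Original time by the accelerated time. For Wan-2.1-T2V-14B-720P, TurboDiffusion comes out at about 204.8x, the upper end of the headline range.

```python
# E2E latencies in seconds on a single RTX 5090, as listed in this section.
RESULTS = {
    "Wan-2.2-I2V-A14B-720P": {"Original": 4183, "TurboDiffusion": 35.4},
    "Wan-2.1-T2V-1.3B-480P": {"Original": 166, "FastVideo": 4.7, "TurboDiffusion": 1.8},
    "Wan-2.1-T2V-14B-720P": {"Original": 4648, "FastVideo": 63.5, "TurboDiffusion": 22.7},
    "Wan-2.1-T2V-14B-480P": {"Original": 1635, "FastVideo": 23.2, "TurboDiffusion": 9.4},
}


def speedups(results):
    """Return Original time / method time for every accelerated method, rounded to 0.1."""
    table = {}
    for model, times in results.items():
        baseline = times["Original"]
        table[model] = {
            method: round(baseline / seconds, 1)
            for method, seconds in times.items()
            if method != "Original"
        }
    return table


if __name__ == "__main__":
    for model, ratios in speedups(RESULTS).items():
        print(model, ", ".join(f"{m}: {r}x" for m, r in ratios.items()))
```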

Training

In this repo, we provide training code based on Wan2.1 and its synthetic data. The training builds on the rCM codebase (https://github.com/NVlabs/rcm), with infrastructure support including FSDP2, Ulysses CP, and selective activation checkpointing (SAC). For rCM training instructions, please refer to the original rCM repository; SLA training guidance is provided here.

Checkpoints Downloading

Download the Wan2.1 pretrained checkpoints (.pth format), VAE, and text encoder to assets/checkpoints:

# make sure git lfs is installed
git clone https://huggingface.co/worstcoder/Wan assets/checkpoints

FSDP2 relies on Distributed Checkpoint (DCP) for loading and saving checkpoints, so convert the .pth teacher checkpoint to .dcp before training:

python -m torch.distributed.checkpoint.format_utils torch_to_dcp assets/checkpoints/Wan2.1-T2V-1.3B.pth assets/checkpoints/Wan2.1-T2V-1.3B.dcp

After training, the saved .dcp checkpoints can be converted to .pth using the script scripts/dcp_to_pth.py.

Dataset Downloading

We provide datasets synthesized with Wan2.1-14B. Download them to assets/datasets:

# make sure git lfs is installed
git clone https://huggingface.co/datasets/worstcoder/Wan_datasets assets/datasets

Start Training

We implement white-box SLA training by aligning the predictions of the SLA-enabled model with those of the full-attention pretrained model. Unlike black-box training, which tunes the pretrained model using diffusion loss, white-box training mitigates distribution shift and is less sensitive to the training data.
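Written as objectives, the contrast reads roughly as follows. This uses our own notation and is a sketch under assumptions (a squared L2 distance, a frozen full-attention teacher, and, for the black-box case, the flow-matching velocity target used by Wan); the repository's exact loss may differ. Here f_θ^SLA is the SLA-enabled model being trained, f^full is the frozen full-attention pretrained model, x_t is a noised latent at time t, and c is the text condition.

```latex
\mathcal{L}_{\text{white-box}}(\theta)
  = \mathbb{E}_{x_t,\,t,\,c}\Big[\big\lVert f^{\text{SLA}}_{\theta}(x_t, t, c) - f^{\text{full}}(x_t, t, c) \big\rVert_2^2\Big]

\mathcal{L}_{\text{black-box}}(\theta)
  = \mathbb{E}_{x_0,\,\epsilon,\,t,\,c}\Big[\big\lVert f^{\text{SLA}}_{\theta}(x_t, t, c) - (\epsilon - x_0) \big\rVert_2^2\Big],
  \qquad x_t = (1 - t)\,x_0 + t\,\epsilon
```

In the white-box objective the target is the pretrained model's own output, so the student only has to recover what sparse attention removed, rather than refit the data distribution through the training samples.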

Single-node training example:

WORKDIR="/your/path/to/turbodiffusion"
cd $WORKDIR
export PYTHONPATH=turbodiffusion

# the "IMAGINAIRE_OUTPUT_ROOT" environment variable is the path to save experiment output files
export IMAGINAIRE_OUTPUT_ROOT=${WORKDIR}/outputs
CHECKPOINT_ROOT=${WORKDIR}/assets/checkpoints
DATASET_ROOT=${WORKDIR}/assets/datasets/Wan2.1_14B_480p_16:9_Euler-step100_shift-3.0_cfg-5.0_seed-0_250K

# your Wandb information
export WANDB_API_KEY=xxx
export WANDB_ENTITY=xxx

registry=registry_sla
experiment=wan2pt1_1pt3B_res480p_t2v_SLA

torchrun --nproc_per_node=8 \
    -m scripts.train --config=rcm/configs/${registry}.py -- experiment=${experiment} \
        model.config.teacher_ckpt=${CHECKPOINT_ROOT}/Wan2.1-T2V-1.3B.dcp \
        model.config.tokenizer.vae_pth=${CHECKPOINT_ROOT}/Wan2.1_VAE.pth \
        model.config.text_encoder_path=${CHECKPOINT_ROOT}/models_t5_umt5-xxl-enc-bf16.pth \
        model.config.neg_embed_path=${CHECKPOINT_ROOT}/umT5_wan_negative_emb.pt \
        dataloader_train.tar_path_pattern=${DATASET_ROOT}/shard*.tar

See turbodiffusion/rcm/configs/experiments/sla/wan2pt1_t2v.py for the 14B config, and modify it as needed.

Model Merging

The parameter updates from SLA training can be merged into rCM checkpoints using turbodiffusion/scripts/merge_models.py, enabling rCM to perform sparse attention inference.
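A common way to transfer fine-tuning updates between checkpoints that share a base model is additive delta merging (a "task vector"): merged = rCM + (SLA-trained - base). The Python sketch below shows that technique on plain dictionaries. We have not confirmed that turbodiffusion/scripts/merge_models.py uses exactly this rule, so merge_sla_into_rcm, its handling of parameters that exist only after SLA training, and the key names in the usage example are all hypothetical.

```python
def merge_sla_into_rcm(base, sla_trained, rcm):
    """Hypothetical additive merge: apply the SLA fine-tuning delta to rCM weights.

    base:        state dict of the model SLA training started from (e.g. Wan2.1)
    sla_trained: state dict after SLA training
    rcm:         state dict of the rCM-distilled model
    Values only need to support + and - (floats here, torch tensors in practice).
    """
    merged = dict(rcm)
    for key, w_sla in sla_trained.items():
        if key not in base:
            # Assumption: parameters that only exist after SLA training are copied as-is.
            merged[key] = w_sla
        elif key in rcm:
            # rCM weights plus the update learned during SLA training.
            merged[key] = rcm[key] + (w_sla - base[key])
    return merged


if __name__ == "__main__":
    base = {"attn.q": 1.0, "mlp": 2.0}
    sla = {"attn.q": 1.5, "mlp": 2.0, "attn.new_sla_param": 0.25}
    rcm = {"attn.q": 0.75, "mlp": 2.5}
    print(merge_sla_into_rcm(base, sla, rcm))
    # {'attn.q': 1.25, 'mlp': 2.5, 'attn.new_sla_param': 0.25}
```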

Roadmap

We're actively working on the following features and improvements:

  • Organize and release training code
  • Optimize infrastructure for better parallelism
  • vLLM-Omni integration
  • Support for more video generation models
  • Support for autoregressive video generation models
  • More hardware-level operator optimizations

We welcome community members to help maintain and extend TurboDiffusion, and we appreciate your interest in contributing. You are welcome to join the TurboDiffusion Team and contribute together!

Citation

If you use this code or find our work valuable, please cite:

@article{zhang2025turbodiffusion,
  title={TurboDiffusion: Accelerating Video Diffusion Models by 100--205 Times},
  author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E. and Chen, Jianfei and Zhu, Jun},
  year={2025}
}

@software{turbodiffusion2025,
  title={TurboDiffusion: Accelerating Video Diffusion Models by 100--205 Times},
  author={{The TurboDiffusion Team}},
  url={https://github.com/thu-ml/TurboDiffusion},
  year={2025}
}

@inproceedings{zhang2025sageattention,
  title={SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration}, 
  author={Zhang, Jintao and Wei, Jia and Zhang, Pengle and Zhu, Jun and Chen, Jianfei},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2025}
}

@article{zhang2025sla,
  title={SLA: Beyond Sparsity in Diffusion Transformers via Fine-Tunable Sparse-Linear Attention},
  author={Zhang, Jintao and Wang, Haoxu and Jiang, Kai and Yang, Shuo and Zheng, Kaiwen and Xi, Haocheng and Wang, Ziteng and Zhu, Hongzhou and Zhao, Min and Stoica, Ion and Gonzalez, Joseph E. and Zhu, Jun and Chen, Jianfei},
  journal={arXiv preprint arXiv:2509.24006},
  year={2025}
}

@article{zheng2025rcm,
  title={Large Scale Diffusion Distillation via Score-Regularized Continuous-Time Consistency},
  author={Zheng, Kaiwen and Wang, Yuji and Ma, Qianli and Chen, Huayu and Zhang, Jintao and Balaji, Yogesh and Chen, Jianfei and Liu, Ming-Yu and Zhu, Jun and Zhang, Qinsheng},
  journal={arXiv preprint arXiv:2510.08431},
  year={2025}
}

@inproceedings{zhang2024sageattention2,
  title={Sageattention2: Efficient attention with thorough outlier smoothing and per-thread int4 quantization},
  author={Zhang, Jintao and Huang, Haofeng and Zhang, Pengle and Wei, Jia and Zhu, Jun and Chen, Jianfei},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2025}
}

About

TurboDiffusion: 100–205× Acceleration of Video Diffusion Models
