diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml
index 1738efd63bb7..8b7e57e91297 100644
--- a/.github/workflows/nightly_tests.yml
+++ b/.github/workflows/nightly_tests.yml
@@ -73,6 +73,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -84,7 +86,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
--report-log=tests_pipeline_${{ matrix.module }}_cuda.log \
tests/pipelines/${{ matrix.module }}
@@ -126,6 +128,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
@@ -138,7 +142,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ -k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
--report-log=tests_torch_${{ matrix.module }}_cuda.log \
tests/${{ matrix.module }}
@@ -151,7 +155,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v --make-reports=examples_torch_cuda \
+ --make-reports=examples_torch_cuda \
--report-log=examples_torch_cuda.log \
examples/
@@ -190,6 +194,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -198,7 +204,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
- pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -232,6 +238,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -281,6 +289,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -293,7 +303,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ -k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_version_cuda \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
@@ -358,6 +368,8 @@ jobs:
uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
fi
uv pip install pytest-reportlog
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -405,6 +417,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip install -U bitsandbytes optimum_quanto
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -531,7 +545,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
-# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
+# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
@@ -587,7 +601,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
-# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
+# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
diff --git a/.github/workflows/pr_modular_tests.yml b/.github/workflows/pr_modular_tests.yml
index 7081ee518d55..13c228621f5c 100644
--- a/.github/workflows/pr_modular_tests.yml
+++ b/.github/workflows/pr_modular_tests.yml
@@ -109,7 +109,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
- uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
@@ -120,7 +121,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index 3306ebe43ef7..674e62ff443a 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -115,7 +115,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
- uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
@@ -126,7 +127,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/pipelines
@@ -134,7 +135,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_models' }}
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx and not Dependency" \
+ -k "not Flax and not Onnx and not Dependency" \
--make-reports=tests_${{ matrix.config.report }} \
tests/models tests/schedulers tests/others
@@ -246,7 +247,8 @@ jobs:
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
uv pip install -U tokenizers
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -255,11 +257,11 @@ jobs:
- name: Run fast PyTorch LoRA tests with PEFT
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v \
+ \
--make-reports=tests_peft_main \
tests/lora/
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v \
+ \
--make-reports=tests_models_lora_peft_main \
tests/models/ -k "lora"
diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml
index 6c208ad7cac7..468979d379c1 100644
--- a/.github/workflows/pr_tests_gpu.yml
+++ b/.github/workflows/pr_tests_gpu.yml
@@ -1,4 +1,4 @@
-name: Fast GPU Tests on PR
+name: Fast GPU Tests on PR
on:
pull_request:
@@ -71,7 +71,7 @@ jobs:
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
-
+
setup_torch_cuda_pipeline_matrix:
needs: [check_code_quality, check_repository_consistency]
name: Setup Torch Pipelines CUDA Slow Tests Matrix
@@ -131,7 +131,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -149,18 +150,18 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- if [ "${{ matrix.module }}" = "ip_adapters" ]; then
+ if [ "${{ matrix.module }}" = "ip_adapters" ]; then
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- else
+ else
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx and $pattern" \
+ -k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- fi
+ fi
- name: Failure short reports
if: ${{ failure() }}
@@ -201,7 +202,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -222,11 +224,11 @@ jobs:
run: |
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
if [ -z "$pattern" ]; then
- pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
- --make-reports=tests_torch_cuda_${{ matrix.module }}
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
+ --make-reports=tests_torch_cuda_${{ matrix.module }}
else
- pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
- --make-reports=tests_torch_cuda_${{ matrix.module }}
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
+ --make-reports=tests_torch_cuda_${{ matrix.module }}
fi
- name: Failure short reports
@@ -262,7 +264,8 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install -e ".[quality,training]"
- name: Environment
@@ -274,7 +277,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
- pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml
index 58133a7f43df..7b1c441d3dc0 100644
--- a/.github/workflows/push_tests.yml
+++ b/.github/workflows/push_tests.yml
@@ -76,7 +76,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -87,7 +88,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
@@ -128,7 +129,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -141,7 +143,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ -k "not Flax and not Onnx" \
--make-reports=tests_torch_cuda_${{ matrix.module }} \
tests/${{ matrix.module }}
@@ -180,7 +182,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
- uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -189,7 +192,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
- pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -230,7 +233,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -273,7 +276,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
- pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml
index ae619d481c48..38cbffaa6315 100644
--- a/.github/workflows/push_tests_fast.yml
+++ b/.github/workflows/push_tests_fast.yml
@@ -70,7 +70,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch' }}
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml
index 484c7a8eeb49..2d6feb592815 100644
--- a/.github/workflows/push_tests_mps.yml
+++ b/.github/workflows/push_tests_mps.yml
@@ -57,7 +57,7 @@ jobs:
HF_HOME: /System/Volumes/Data/mnt/cache
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
- ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
+ ${CONDA_RUN} python -m pytest -n 0 --make-reports=tests_torch_mps tests/
- name: Failure short reports
if: ${{ failure() }}
diff --git a/.github/workflows/release_tests_fast.yml b/.github/workflows/release_tests_fast.yml
index 808818beada3..efdd6ea2b651 100644
--- a/.github/workflows/release_tests_fast.yml
+++ b/.github/workflows/release_tests_fast.yml
@@ -84,7 +84,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
@@ -137,7 +137,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ -k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
tests/${{ matrix.module }}
@@ -187,7 +187,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ -k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_cuda \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
@@ -240,7 +240,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
- pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -281,7 +281,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -326,7 +326,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
- pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 55fe2a9a379f..24420af8e490 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -22,6 +22,8 @@
title: Reproducibility
- local: using-diffusers/schedulers
title: Schedulers
+ - local: using-diffusers/automodel
+ title: AutoModel
- local: using-diffusers/other-formats
title: Model formats
- local: using-diffusers/push_to_hub
@@ -119,6 +121,8 @@
title: ComponentsManager
- local: modular_diffusers/guiders
title: Guiders
+ - local: modular_diffusers/custom_blocks
+ title: Building Custom Blocks
title: Modular Diffusers
- isExpanded: false
sections:
@@ -387,6 +391,8 @@
title: Transformer2DModel
- local: api/models/transformer_temporal
title: TransformerTemporalModel
+ - local: api/models/wan_animate_transformer_3d
+ title: WanAnimateTransformer3DModel
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
title: Transformers
@@ -448,6 +454,8 @@
- sections:
- local: api/pipelines/overview
title: Overview
+ - local: api/pipelines/auto_pipeline
+ title: AutoPipeline
- sections:
- local: api/pipelines/audioldm
title: AudioLDM
@@ -460,8 +468,6 @@
- local: api/pipelines/stable_audio
title: Stable Audio
title: Audio
- - local: api/pipelines/auto_pipeline
- title: AutoPipeline
- sections:
- local: api/pipelines/amused
title: aMUSEd
@@ -525,6 +531,8 @@
title: HiDream-I1
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
+ - local: api/pipelines/hunyuanimage21
+ title: HunyuanImage2.1
- local: api/pipelines/pix2pix
title: InstructPix2Pix
- local: api/pipelines/kandinsky
@@ -638,8 +646,6 @@
title: ConsisID
- local: api/pipelines/framepack
title: Framepack
- - local: api/pipelines/hunyuanimage21
- title: HunyuanImage2.1
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/i2vgenxl
diff --git a/docs/source/en/api/models/auto_model.md b/docs/source/en/api/models/auto_model.md
index 376dd12d12c4..aee9b5dbe50c 100644
--- a/docs/source/en/api/models/auto_model.md
+++ b/docs/source/en/api/models/auto_model.md
@@ -12,15 +12,7 @@ specific language governing permissions and limitations under the License.
# AutoModel
-The `AutoModel` is designed to make it easy to load a checkpoint without needing to know the specific model class. `AutoModel` automatically retrieves the correct model class from the checkpoint `config.json` file.
-
-```python
-from diffusers import AutoModel, AutoPipelineForText2Image
-
-unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
-pipe = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet)
-```
-
+[`AutoModel`] automatically retrieves the correct model class from the checkpoint `config.json` file.
## AutoModel
diff --git a/docs/source/en/api/models/wan_animate_transformer_3d.md b/docs/source/en/api/models/wan_animate_transformer_3d.md
new file mode 100644
index 000000000000..cc7b3f0c408c
--- /dev/null
+++ b/docs/source/en/api/models/wan_animate_transformer_3d.md
@@ -0,0 +1,30 @@
+
+
+# WanAnimateTransformer3DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [Wan Animate](https://github.com/Wan-Video/Wan2.2) by the Alibaba Wan Team.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+from diffusers import WanAnimateTransformer3DModel
+
+transformer = WanAnimateTransformer3DModel.from_pretrained("Wan-AI/Wan2.2-Animate-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## WanAnimateTransformer3DModel
+
+[[autodoc]] WanAnimateTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/sana_video.md b/docs/source/en/api/pipelines/sana_video.md
index 85d77fb2944b..d69f4a95facc 100644
--- a/docs/source/en/api/pipelines/sana_video.md
+++ b/docs/source/en/api/pipelines/sana_video.md
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->
-# SanaVideoPipeline
+# Sana-Video

@@ -37,6 +37,85 @@ Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-vi
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
+
+## Generation Pipelines
+
+
+
+
+The example below demonstrates how to use the text-to-video pipeline to generate a video from a text description.
+
+```python
+import torch
+from diffusers import SanaVideoPipeline
+from diffusers.utils import export_to_video
+
+model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
+pipe = SanaVideoPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+pipe.text_encoder.to(torch.bfloat16)
+pipe.vae.to(torch.float32)
+pipe.to("cuda")
+
+prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
+motion_scale = 30
+motion_prompt = f" motion score: {motion_scale}."
+prompt = prompt + motion_prompt
+
+video = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=480,
+ width=832,
+ frames=81,
+ guidance_scale=6,
+ num_inference_steps=50,
+ generator=torch.Generator(device="cuda").manual_seed(0),
+).frames[0]
+
+export_to_video(video, "sana_video.mp4", fps=16)
+```
+
+
+
+
+The example below demonstrates how to use the image-to-video pipeline to generate a video from a text description and a starting frame.
+
+```python
+import torch
+from diffusers import FlowMatchEulerDiscreteScheduler, SanaImageToVideoPipeline
+from diffusers.utils import export_to_video, load_image
+
+model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
+pipe = SanaImageToVideoPipeline.from_pretrained(
+ model_id,
+ torch_dtype=torch.bfloat16,
+)
+pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
+pipe.vae.to(torch.float32)
+pipe.text_encoder.to(torch.bfloat16)
+pipe.to("cuda")
+
+image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")
+prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle."
+negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
+motion_scale = 30
+motion_prompt = f" motion score: {motion_scale}."
+prompt = prompt + motion_prompt
+
+
+video = pipe(
+ image=image,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=480,
+ width=832,
+ frames=81,
+ guidance_scale=6,
+ num_inference_steps=50,
+ generator=torch.Generator(device="cuda").manual_seed(0),
+).frames[0]
+
+export_to_video(video, "sana-i2v.mp4", fps=16)
+```
+
+
+
+
+
## Quantization
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
@@ -97,6 +176,13 @@ export_to_video(output, "sana-video-output.mp4", fps=16)
- __call__
+## SanaImageToVideoPipeline
+
+[[autodoc]] SanaImageToVideoPipeline
+ - all
+ - __call__
+
+
## SanaVideoPipelineOutput
-[[autodoc]] pipelines.sana.pipeline_sana_video.SanaVideoPipelineOutput
+[[autodoc]] pipelines.sana_video.pipeline_sana_video.SanaVideoPipelineOutput
diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md
index 3289a840e2b1..0fc17bc6ea26 100644
--- a/docs/source/en/api/pipelines/wan.md
+++ b/docs/source/en/api/pipelines/wan.md
@@ -40,6 +40,8 @@ The following Wan models are supported in Diffusers:
- [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
- [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)
- [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)
+- [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)
+- [Wan 2.2 S2V 14B](https://huggingface.co/Wan-AI/Wan2.2-S2V-14B-Diffusers)
> [!TIP]
> Click on the Wan models in the right sidebar for more examples of video generation.
@@ -95,15 +97,15 @@ pipeline = WanPipeline.from_pretrained(
pipeline.to("cuda")
prompt = """
-The camera rushes from far to near in a low-angle shot,
-revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
-for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
-Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
+The camera rushes from far to near in a low-angle shot,
+revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
+for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
+Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
-Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
-low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
+Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
+low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
@@ -150,15 +152,15 @@ pipeline.transformer = torch.compile(
)
prompt = """
-The camera rushes from far to near in a low-angle shot,
-revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
-for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
-Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
+The camera rushes from far to near in a low-angle shot,
+revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
+for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
+Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
-Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
-low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
+Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
+low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
@@ -236,6 +238,7 @@ export_to_video(output, "output.mp4", fps=16)
+
### Any-to-Video Controllable Generation
Wan VACE supports various generation techniques which achieve controllable video generation. Some of the capabilities include:
@@ -249,6 +252,330 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
+
+
+
+### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
+
+[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
+
+*We introduce Wan-Animate, a unified framework for character animation and replacement. Given a character image and a reference video, Wan-Animate can animate the character by precisely replicating the expressions and movements of the character in the video to generate high-fidelity character videos. Alternatively, it can integrate the animated character into the reference video to replace the original character, replicating the scene's lighting and color tone to achieve seamless environmental integration. Wan-Animate is built upon the Wan model. To adapt it for character animation tasks, we employ a modified input paradigm to differentiate between reference conditions and regions for generation. This design unifies multiple tasks into a common symbolic representation. We use spatially-aligned skeleton signals to replicate body motion and implicit facial features extracted from source images to reenact expressions, enabling the generation of character videos with high controllability and expressiveness. Furthermore, to enhance environmental integration during character replacement, we develop an auxiliary Relighting LoRA. This module preserves the character's appearance consistency while applying the appropriate environmental lighting and color tone. Experimental results demonstrate that Wan-Animate achieves state-of-the-art performance. We are committed to open-sourcing the model weights and its source code.*
+
+The project page: https://humanaigc.github.io/wan-animate
+
+This model was mostly contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
+
+#### Usage
+
+The Wan-Animate pipeline supports two modes of operation:
+
+1. **Animation Mode** (default): Animates a character image based on motion and expression from reference videos
+2. **Replacement Mode**: Replaces a character in a background video with a new character while preserving the scene
+
+##### Prerequisites
+
+Before using the pipeline, you need to preprocess your reference video to extract:
+- **Pose video**: Contains skeletal keypoints representing body motion
+- **Face video**: Contains facial feature representations for expression control
+
+For replacement mode, you additionally need:
+- **Background video**: The original video containing the scene
+- **Mask video**: A mask indicating where to generate content (white) vs. preserve original (black)
+
+> [!NOTE]
+> Raw videos should not be passed directly for inputs such as `pose_video`; the pipeline expects these inputs to already be preprocessed into the representations described above. Preprocessing scripts to prepare these inputs are available in the [original Wan-Animate repository](https://github.com/Wan-Video/Wan2.2?tab=readme-ov-file#1-preprocessing). Integration of these preprocessing steps into Diffusers is planned for a future release.
+
+The example below demonstrates how to use the Wan-Animate pipeline:
+
+
+
+
+```python
+import numpy as np
+import torch
+from diffusers import AutoencoderKLWan, WanAnimatePipeline
+from diffusers.utils import export_to_video, load_image, load_video
+
+model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# Load character image and preprocessed videos
+image = load_image("path/to/character.jpg")
+pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
+face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
+
+# Resize image to match VAE constraints
+def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
+ aspect_ratio = image.height / image.width
+ mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ image = image.resize((width, height))
+ return image, height, width
+
+image, height, width = aspect_ratio_resize(image, pipe)
+
+prompt = "A person dancing energetically in a studio with dynamic lighting and professional camera work"
+negative_prompt = "blurry, low quality, distorted, deformed, static, poorly drawn"
+
+# Generate animated video
+output = pipe(
+ image=image,
+ pose_video=pose_video,
+ face_video=face_video,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ segment_frame_length=77,
+ guidance_scale=1.0,
+ mode="animate", # Animation mode (default)
+).frames[0]
+export_to_video(output, "animated_character.mp4", fps=30)
+```
+
+
+
+
+```python
+import numpy as np
+import torch
+from diffusers import AutoencoderKLWan, WanAnimatePipeline
+from diffusers.utils import export_to_video, load_image, load_video
+
+model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# Load all required inputs for replacement mode
+image = load_image("path/to/new_character.jpg")
+pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
+face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
+background_video = load_video("path/to/background_video.mp4") # Original scene
+mask_video = load_video("path/to/mask_video.mp4") # Black: preserve, White: generate
+
+# Resize image to match video dimensions
+def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
+ aspect_ratio = image.height / image.width
+ mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ image = image.resize((width, height))
+ return image, height, width
+
+image, height, width = aspect_ratio_resize(image, pipe)
+
+prompt = "A person seamlessly integrated into the scene with consistent lighting and environment"
+negative_prompt = "blurry, low quality, inconsistent lighting, floating, disconnected from scene"
+
+# Replace character in background video
+output = pipe(
+ image=image,
+ pose_video=pose_video,
+ face_video=face_video,
+ background_video=background_video,
+ mask_video=mask_video,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+    segment_frame_length=77,
+ guidance_scale=1.0,
+ mode="replace", # Replacement mode
+).frames[0]
+export_to_video(output, "character_replaced.mp4", fps=30)
+```
+
+
+
+
+```python
+import numpy as np
+import torch
+from diffusers import AutoencoderKLWan, WanAnimatePipeline
+from diffusers.utils import export_to_video, load_image, load_video
+
+model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+image = load_image("path/to/character.jpg")
+pose_video = load_video("path/to/pose_video.mp4")
+face_video = load_video("path/to/face_video.mp4")
+
+def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
+ aspect_ratio = image.height / image.width
+ mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ image = image.resize((width, height))
+ return image, height, width
+
+image, height, width = aspect_ratio_resize(image, pipe)
+
+prompt = "A person dancing energetically in a studio"
+negative_prompt = "blurry, low quality"
+
+# Advanced: Use temporal guidance and custom callback
+def callback_fn(pipe, step_index, timestep, callback_kwargs):
+ # You can modify latents or other tensors here
+ print(f"Step {step_index}, Timestep {timestep}")
+ return callback_kwargs
+
+output = pipe(
+ image=image,
+ pose_video=pose_video,
+ face_video=face_video,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ segment_frame_length=77,
+ num_inference_steps=50,
+ guidance_scale=5.0,
+ prev_segment_conditioning_frames=5, # Use 5 frames for temporal guidance (1 or 5 recommended)
+ callback_on_step_end=callback_fn,
+ callback_on_step_end_tensor_inputs=["latents"],
+).frames[0]
+export_to_video(output, "animated_advanced.mp4", fps=30)
+```
+
+
+
+
+#### Key Parameters
+
+- **mode**: Choose between `"animate"` (default) or `"replace"`
+- **prev_segment_conditioning_frames**: Number of frames for temporal guidance (1 or 5 recommended). Using 5 provides better temporal consistency but requires more memory
+- **guidance_scale**: Controls how closely the output follows the text prompt. Higher values (5-7) produce results more aligned with the prompt. For Wan-Animate, CFG is disabled by default (`guidance_scale=1.0`) but can be enabled to support negative prompts and finer control over facial expressions, as shown in the sketch below. (Note that CFG will only target the text prompt and face conditioning.)
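+
+Putting these parameters together, here is a minimal sketch that enables CFG on top of the animation-mode example above (it assumes `pipe`, `image`, `pose_video`, `face_video`, and the prompts are prepared exactly as shown there):
+
+```python
+# Sketch only: all inputs are assumed to be prepared as in the animation-mode example above.
+output = pipe(
+    image=image,
+    pose_video=pose_video,
+    face_video=face_video,
+    prompt=prompt,
+    negative_prompt=negative_prompt,  # only takes effect once CFG is enabled
+    height=height,
+    width=width,
+    segment_frame_length=77,
+    guidance_scale=5.0,  # > 1.0 enables CFG over the text prompt and face conditioning
+    mode="animate",
+).frames[0]
+```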
+
+
+### Wan-S2V: Audio-Driven Cinematic Video Generation
+
+[Wan-S2V](https://huggingface.co/papers/2508.18621) by the Wan Team.
+
+*Current state-of-the-art (SOTA) methods for audio-driven character animation demonstrate promising performance for scenarios primarily involving speech and singing. However, they often fall short in more complex film and television productions, which demand sophisticated elements such as nuanced character interactions, realistic body movements, and dynamic camera work. To address this long-standing challenge of achieving film-level character animation, we propose an audio-driven model, which we refere to as Wan-S2V, built upon Wan. Our model achieves significantly enhanced expressiveness and fidelity in cinematic contexts compared to existing approaches. We conducted extensive experiments, benchmarking our method against cutting-edge models such as Hunyuan-Avatar and Omnihuman. The experimental results consistently demonstrate that our approach significantly outperforms these existing solutions. Additionally, we explore the versatility of our method through its applications in long-form video generation and precise video lip-sync editing.*
+
+The project page: https://humanaigc.github.io/wan-s2v-webpage/
+
+This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
+
+The example below demonstrates how to use the speech-to-video pipeline to generate a video using a text description, a starting frame, an audio, and a pose video.
+
+
+
+
+```python
+import numpy as np, math
+import torch
+from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline
+from diffusers.utils import export_to_merged_video_audio, load_image, load_audio, load_video, export_to_video
+from transformers import Wav2Vec2ForCTC
+import requests
+from PIL import Image
+from io import BytesIO
+
+
+model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers"
+audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanSpeechToVideoPipeline.from_pretrained(
+ model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+headers = {"User-Agent": "Mozilla/5.0"}
+url = "https://upload.wikimedia.org/wikipedia/commons/4/46/Albert_Einstein_sticks_his_tongue.jpg"
+resp = requests.get(url, headers=headers, timeout=30)
+image = Image.open(BytesIO(resp.content))
+
+audio, sampling_rate = load_audio("https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/Five%20Hundred%20Miles.MP3")
+#pose_video_path_or_url = "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4"
+
+def get_size_less_than_area(height,
+ width,
+ target_area=1024 * 704,
+ divisor=64):
+ if height * width <= target_area:
+ # If the original image area is already less than or equal to the target,
+ # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target.
+ max_upper_area = target_area
+ min_scale = 0.1
+ max_scale = 1.0
+ else:
+ # Resize to fit within the target area and then pad to multiples of `divisor`
+ max_upper_area = target_area # Maximum allowed total pixel count after padding
+ d = divisor - 1
+ b = d * (height + width)
+ a = height * width
+ c = d**2 - max_upper_area
+
+ # Calculate scale boundaries using quadratic equation
+ min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (2 * a) # Scale when maximum padding is applied
+ max_scale = math.sqrt(max_upper_area / (height * width)) # Scale without any padding
+
+ # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area
+ # Use binary search-like iteration to find this scale
+ find_it = False
+ for i in range(100):
+ scale = max_scale - (max_scale - min_scale) * i / 100
+ new_height, new_width = int(height * scale), int(width * scale)
+
+ # Pad to make dimensions divisible by 64
+ pad_height = (64 - new_height % 64) % 64
+ pad_width = (64 - new_width % 64) % 64
+ pad_top = pad_height // 2
+ pad_bottom = pad_height - pad_top
+ pad_left = pad_width // 2
+ pad_right = pad_width - pad_left
+
+ padded_height, padded_width = new_height + pad_height, new_width + pad_width
+
+ if padded_height * padded_width <= max_upper_area:
+ find_it = True
+ break
+
+ if find_it:
+ return padded_height, padded_width
+ else:
+ # Fallback: calculate target dimensions based on aspect ratio and divisor alignment
+ aspect_ratio = width / height
+ target_width = int(
+ (target_area * aspect_ratio)**0.5 // divisor * divisor)
+ target_height = int(
+ (target_area / aspect_ratio)**0.5 // divisor * divisor)
+
+ # Ensure the result is not larger than the original resolution
+ if target_width >= width or target_height >= height:
+ target_width = int(width // divisor * divisor)
+ target_height = int(height // divisor * divisor)
+
+ return target_height, target_width
+
+height, width = get_size_less_than_area(image.height, image.width, 480*832)
+
+prompt = "Einstein singing a song."
+
+output = pipe(
+ prompt=prompt, image=image, audio=audio, sampling_rate=sampling_rate,
+ height=height, width=width, num_frames_per_chunk=80,
+ #pose_video_path_or_url=pose_video_path_or_url,
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
+
+# Lastly, we need to merge the video and audio into a new video, with the duration set to
+# the shorter of the two and overwrite the original video file.
+export_to_merged_video_audio("output.mp4", "audio.mp3")
+```
+
+
+
+
+
## Notes
- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
@@ -281,10 +608,10 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
# use "steamboat willie style" to trigger the LoRA
prompt = """
- steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
- revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
- for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
- Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
+ steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
+ revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
+ for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
+ Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
@@ -359,6 +686,18 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
- all
- __call__
+## WanAnimatePipeline
+
+[[autodoc]] WanAnimatePipeline
+ - all
+ - __call__
+
+## WanSpeechToVideoPipeline
+
+[[autodoc]] WanSpeechToVideoPipeline
+ - all
+ - __call__
+
## WanPipelineOutput
-[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
\ No newline at end of file
+[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
diff --git a/docs/source/en/modular_diffusers/custom_blocks.md b/docs/source/en/modular_diffusers/custom_blocks.md
new file mode 100644
index 000000000000..1c311582264e
--- /dev/null
+++ b/docs/source/en/modular_diffusers/custom_blocks.md
@@ -0,0 +1,492 @@
+
+
+
+# Building Custom Blocks
+
+[ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block.
+
+> [!TIP]
+> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom modular blocks like Nano Banana.
+
+## Project Structure
+
+Your custom block project should use the following structure:
+
+```shell
+.
+├── block.py
+└── modular_config.json
+```
+
+- `block.py` contains the custom block implementation
+- `modular_config.json` contains the metadata needed to load the block (a rough sketch of its contents is shown below)
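+
+As a rough, hypothetical illustration only (the exact schema is defined by Modular Diffusers, so copy it from one of the official custom block repos in the collection linked above rather than from here), the metadata essentially points the loader at the block class implemented in `block.py` (here, the `Florence2ImageAnnotatorBlock` built in the example below):
+
+```py
+# Hypothetical sketch of the kind of metadata modular_config.json carries.
+# The keys below are assumptions, not the documented schema -- inspect an
+# official custom block repo for the exact format before relying on this.
+import json
+
+metadata = {
+    "_class_name": "Florence2ImageAnnotatorBlock",  # assumed key: the block class defined in block.py
+    "auto_map": {"ModularPipelineBlocks": "block.Florence2ImageAnnotatorBlock"},  # assumed key
+}
+
+with open("modular_config.json", "w") as f:
+    json.dump(metadata, f, indent=2)
+```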
+
+## Example: Florence 2 Inpainting Block
+
+In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting.
+
+The first step is to define the components that the block will use. In this case, we need the `Florence2ForConditionalGeneration` model and its corresponding `AutoProcessor`. When defining components, we must specify the component's name within our pipeline, its model class via `type_hint`, and a `pretrained_model_name_or_path` if we intend to load the model weights from a specific repository on the Hub.
+
+```py
+# Inside block.py
+from diffusers.modular_pipelines import (
+ ModularPipelineBlocks,
+ ComponentSpec,
+)
+from transformers import AutoProcessor, Florence2ForConditionalGeneration
+
+
+class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
+
+ @property
+ def expected_components(self):
+ return [
+ ComponentSpec(
+ name="image_annotator",
+ type_hint=Florence2ForConditionalGeneration,
+ pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
+ ),
+ ComponentSpec(
+ name="image_annotator_processor",
+ type_hint=AutoProcessor,
+ pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
+ ),
+ ]
+```
+
+Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations.
+
+```py
+from typing import List, Union
+from PIL import Image, ImageDraw
+import torch
+import numpy as np
+
+from diffusers.modular_pipelines import (
+ PipelineState,
+ ModularPipelineBlocks,
+ InputParam,
+ ComponentSpec,
+ OutputParam,
+)
+from transformers import AutoProcessor, Florence2ForConditionalGeneration
+
+
+class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
+
+ @property
+ def expected_components(self):
+ return [
+ ComponentSpec(
+ name="image_annotator",
+ type_hint=Florence2ForConditionalGeneration,
+ pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
+ ),
+ ComponentSpec(
+ name="image_annotator_processor",
+ type_hint=AutoProcessor,
+ pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "image",
+ type_hint=Union[Image.Image, List[Image.Image]],
+ required=True,
+ description="Image(s) to annotate",
+ ),
+ InputParam(
+ "annotation_task",
+ type_hint=Union[str, List[str]],
+ required=True,
+                default="",
+ description="""Annotation Task to perform on the image.
+ Supported Tasks:
+
+
+
+
+
+
+
+
+
+
+ """,
+ ),
+ InputParam(
+ "annotation_prompt",
+ type_hint=Union[str, List[str]],
+ required=True,
+ description="""Annotation Prompt to provide more context to the task.
+ Can be used to detect or segment out specific elements in the image
+ """,
+ ),
+ InputParam(
+ "annotation_output_type",
+ type_hint=str,
+ required=True,
+ default="mask_image",
+                description="""Output type from annotation predictions. Available options are
+ mask_image:
+ -black and white mask image for the given image based on the task type
+ mask_overlay:
+ - mask overlayed on the original image
+ bounding_box:
+ - bounding boxes drawn on the original image
+ """,
+ ),
+ InputParam(
+ "annotation_overlay",
+ type_hint=bool,
+ required=True,
+ default=False,
+ description="",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "mask_image",
+ type_hint=Image,
+ description="Inpainting Mask for input Image(s)",
+ ),
+ OutputParam(
+ "annotations",
+ type_hint=dict,
+ description="Annotations Predictions for input Image(s)",
+ ),
+ OutputParam(
+ "image",
+ type_hint=Image,
+ description="Annotated input Image(s)",
+ ),
+ ]
+
+```
+
+Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask.
+
+```py
+from typing import List, Union
+from PIL import Image, ImageDraw
+import torch
+import numpy as np
+
+from diffusers.modular_pipelines import (
+ PipelineState,
+ ModularPipelineBlocks,
+ InputParam,
+ ComponentSpec,
+ OutputParam,
+)
+from transformers import AutoProcessor, Florence2ForConditionalGeneration
+
+
+class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
+
+ @property
+ def expected_components(self):
+ return [
+ ComponentSpec(
+ name="image_annotator",
+ type_hint=Florence2ForConditionalGeneration,
+ pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
+ ),
+ ComponentSpec(
+ name="image_annotator_processor",
+ type_hint=AutoProcessor,
+ pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "image",
+ type_hint=Union[Image.Image, List[Image.Image]],
+ required=True,
+ description="Image(s) to annotate",
+ ),
+ InputParam(
+ "annotation_task",
+ type_hint=Union[str, List[str]],
+ required=True,
+                default="<REFERRING_EXPRESSION_SEGMENTATION>",
+                description="""Annotation Task to perform on the image.
+                Supported Tasks:
+                    <OD>
+                    <CAPTION_TO_PHRASE_GROUNDING>
+                    <DENSE_REGION_CAPTION>
+                    <REGION_PROPOSAL>
+                    <OCR_WITH_REGION>
+                    <REFERRING_EXPRESSION_SEGMENTATION>
+                    <REGION_TO_SEGMENTATION>
+                    <OPEN_VOCABULARY_DETECTION>
+                    <REGION_TO_CATEGORY>
+                    <REGION_TO_DESCRIPTION>
+                """,
+ ),
+ InputParam(
+ "annotation_prompt",
+ type_hint=Union[str, List[str]],
+ required=True,
+ description="""Annotation Prompt to provide more context to the task.
+ Can be used to detect or segment out specific elements in the image
+ """,
+ ),
+ InputParam(
+ "annotation_output_type",
+ type_hint=str,
+ required=True,
+ default="mask_image",
+                description="""Output type from annotation predictions. Available options are
+                mask_image:
+                    - black and white mask image for the given image based on the task type
+                mask_overlay:
+                    - mask overlaid on the original image
+                bounding_box:
+                    - bounding boxes drawn on the original image
+                """,
+ ),
+ InputParam(
+ "annotation_overlay",
+ type_hint=bool,
+ required=True,
+ default=False,
+                description="Whether to overlay the annotations on the original image",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "mask_image",
+                type_hint=Image.Image,
+ description="Inpainting Mask for input Image(s)",
+ ),
+ OutputParam(
+ "annotations",
+ type_hint=dict,
+ description="Annotations Predictions for input Image(s)",
+ ),
+ OutputParam(
+ "image",
+                type_hint=Image.Image,
+ description="Annotated input Image(s)",
+ ),
+ ]
+
+ def get_annotations(self, components, images, prompts, task):
+ task_prompts = [task + prompt for prompt in prompts]
+
+ inputs = components.image_annotator_processor(
+ text=task_prompts, images=images, return_tensors="pt"
+ ).to(components.image_annotator.device, components.image_annotator.dtype)
+
+ generated_ids = components.image_annotator.generate(
+ input_ids=inputs["input_ids"],
+ pixel_values=inputs["pixel_values"],
+ max_new_tokens=1024,
+ early_stopping=False,
+ do_sample=False,
+ num_beams=3,
+ )
+ annotations = components.image_annotator_processor.batch_decode(
+ generated_ids, skip_special_tokens=False
+ )
+ outputs = []
+ for image, annotation in zip(images, annotations):
+ outputs.append(
+ components.image_annotator_processor.post_process_generation(
+ annotation, task=task, image_size=(image.width, image.height)
+ )
+ )
+ return outputs
+
+ def prepare_mask(self, images, annotations, overlay=False, fill="white"):
+ masks = []
+ for image, annotation in zip(images, annotations):
+ mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
+ draw = ImageDraw.Draw(mask_image)
+
+ for _, _annotation in annotation.items():
+ if "polygons" in _annotation:
+ for polygon in _annotation["polygons"]:
+ polygon = np.array(polygon).reshape(-1, 2)
+ if len(polygon) < 3:
+ continue
+ polygon = polygon.reshape(-1).tolist()
+ draw.polygon(polygon, fill=fill)
+
+                elif "bboxes" in _annotation:
+                    for bbox in _annotation["bboxes"]:
+                        draw.rectangle(bbox, fill=fill)
+
+ masks.append(mask_image)
+
+ return masks
+
+ def prepare_bounding_boxes(self, images, annotations):
+ outputs = []
+ for image, annotation in zip(images, annotations):
+ image_copy = image.copy()
+ draw = ImageDraw.Draw(image_copy)
+            for _, _annotation in annotation.items():
+                for bbox, label in zip(_annotation["bboxes"], _annotation["labels"]):
+                    draw.rectangle(bbox, outline="red", width=3)
+                    draw.text((bbox[0], bbox[1] - 20), label, fill="red")
+
+ outputs.append(image_copy)
+
+ return outputs
+
+ def prepare_inputs(self, images, prompts):
+ prompts = prompts or ""
+
+ if isinstance(images, Image.Image):
+ images = [images]
+ if isinstance(prompts, str):
+ prompts = [prompts]
+
+ if len(images) != len(prompts):
+ raise ValueError("Number of images and annotation prompts must match.")
+
+ return images, prompts
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ images, annotation_task_prompt = self.prepare_inputs(
+ block_state.image, block_state.annotation_prompt
+ )
+ task = block_state.annotation_task
+        fill = getattr(block_state, "fill", "white")  # mask fill color; defaults to white since "fill" is not a declared input
+
+ annotations = self.get_annotations(
+ components, images, annotation_task_prompt, task
+ )
+ block_state.annotations = annotations
+ if block_state.annotation_output_type == "mask_image":
+ block_state.mask_image = self.prepare_mask(images, annotations)
+ else:
+ block_state.mask_image = None
+
+ if block_state.annotation_output_type == "mask_overlay":
+ block_state.image = self.prepare_mask(images, annotations, overlay=True, fill=fill)
+
+ elif block_state.annotation_output_type == "bounding_box":
+ block_state.image = self.prepare_bounding_boxes(images, annotations)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+```
+
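+You can sanity-check the block locally before sharing it. The snippet below is only a rough sketch: it assumes a standalone block can be turned into a runnable pipeline with `init_pipeline()` (loading the components declared in `expected_components`), and it reuses the example image and the `output` retrieval pattern shown later in this guide.
+
+```py
+import torch
+
+from block import Florence2ImageAnnotatorBlock
+from diffusers.utils import load_image
+
+block = Florence2ImageAnnotatorBlock()
+pipe = block.init_pipeline()
+pipe.load_components(torch_dtype=torch.float16, device_map="cuda")
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true")
+
+# request the `mask_image` intermediate output declared by the block
+mask = pipe(
+    image=image,
+    annotation_task="<REFERRING_EXPRESSION_SEGMENTATION>",
+    annotation_prompt=["the car"],
+    annotation_output_type="mask_image",
+    output="mask_image",
+)
+mask[0].save("mask.png")
+```
+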
+Once we have defined our custom block, we can save it to the Hub using either the CLI or the [`push_to_hub`] method. This makes it easy to share and reuse the custom block with other pipelines.
+
+```shell
+# In the folder with the `block.py` file, run:
+diffusers-cli custom_block
+```
+
+Then upload the block to the Hub:
+
+```shell
+hf upload <your-repo-id> . .
+```
+
+Alternatively, use the [`push_to_hub`] method:
+
+```py
+from block import Florence2ImageAnnotatorBlock
+block = Florence2ImageAnnotatorBlock()
+block.push_to_hub("<your-repo-id>")
+```
+
+## Using Custom Blocks
+
+Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`.
+
+```py
+import torch
+from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
+from diffusers.utils import load_image
+
+# Fetch the Florence2 image annotator block that will create our mask
+image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True)
+
+my_blocks = INPAINT_BLOCKS.copy()
+# insert the annotation block before the image encoding step
+my_blocks.insert("image_annotator", image_annotator_block, 1)
+
+# Build the combined inpainting blocks, now including the annotation step
+blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks)
+
+repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0"
+pipe = blocks.init_pipeline(repo_id)
+pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True)
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true")
+image = image.resize((1024, 1024))
+
+prompt = ["A red car"]
+annotation_task = "<REFERRING_EXPRESSION_SEGMENTATION>"
+annotation_prompt = ["the car"]
+
+output = pipe(
+ prompt=prompt,
+ image=image,
+ annotation_task=annotation_task,
+ annotation_prompt=annotation_prompt,
+ annotation_output_type="mask_image",
+ num_inference_steps=35,
+ guidance_scale=7.5,
+ strength=0.95,
+ output="images"
+)
+output[0].save("florence-inpainting.png")
+```
+
+## Editing Custom Blocks
+
+By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder.
+
+```py
+import torch
+from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
+from diffusers.utils import load_image
+
+# Fetch the Florence2 image annotator block that will create our mask
+image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder")
+```
+
+Any changes made to the block files in this folder will be reflected when you load the block again.
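+
+For example, after editing `block.py` in that folder, you can reload the block from the local path. This is a minimal sketch and assumes [`~ModularPipelineBlocks.from_pretrained`] accepts a local folder containing the block code and config the same way it accepts a Hub repo id.
+
+```py
+from diffusers.modular_pipelines import ModularPipelineBlocks
+
+# reload the edited block directly from the local folder
+image_annotator_block = ModularPipelineBlocks.from_pretrained(
+    "/my-local-folder", trust_remote_code=True
+)
+```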
diff --git a/docs/source/en/using-diffusers/automodel.md b/docs/source/en/using-diffusers/automodel.md
new file mode 100644
index 000000000000..957cbd17e3f7
--- /dev/null
+++ b/docs/source/en/using-diffusers/automodel.md
@@ -0,0 +1,46 @@
+
+
+# AutoModel
+
+The [`AutoModel`] class automatically detects and loads the correct model class (UNet, transformer, VAE) from a `config.json` file. You don't need to know the specific model class name ahead of time. It supports specifying the data type and device placement, and it works across model types and libraries.
+
+The example below loads a transformer from Diffusers and a text encoder from Transformers. Use the `subfolder` parameter to specify where to load the `config.json` file from.
+
+```py
+import torch
+from diffusers import AutoModel, DiffusionPipeline
+
+transformer = AutoModel.from_pretrained(
+ "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+
+text_encoder = AutoModel.from_pretrained(
+ "Qwen/Qwen-Image", subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+```
+
+[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.
+
+```py
+import torch
+from diffusers import AutoModel
+
+transformer = AutoModel.from_pretrained(
+ "custom/custom-transformer-model", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
+)
+```
+
+If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).
+
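+As a rough sketch (assuming the `enable_group_offload` and regional compilation helpers described in the linked guides), a custom model loaded through [`AutoModel`] could be optimized like this:
+
+```py
+import torch
+from diffusers import AutoModel
+
+transformer = AutoModel.from_pretrained(
+    "custom/custom-transformer-model", trust_remote_code=True, torch_dtype=torch.bfloat16
+)
+
+# group offloading: keep weights on the CPU and move groups of layers to the GPU only when needed
+transformer.enable_group_offload(
+    onload_device=torch.device("cuda"),
+    offload_device=torch.device("cpu"),
+    offload_type="leaf_level",
+)
+
+# regional compilation (usable independently of offloading): compile only the repeated transformer blocks
+transformer.compile_repeated_blocks(fullgraph=True)
+```
+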
+> [!NOTE]
+> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
\ No newline at end of file
diff --git a/examples/community/README.md b/examples/community/README.md
index 4a4b0f5fd9f5..4ff9c4d77704 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -88,7 +88,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
| Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - | [Net-Mist](https://github.com/Net-Mist) |
-
+| Flux Fill ControlNet Pipeline | A modified version of the `FluxFillPipeline` and `FluxControlNetInpaintPipeline` that supports ControlNet with the Flux Fill model. | [Flux Fill ControlNet Pipeline](#flux-fill-controlnet-pipeline) | - | [pratim4dasude](https://github.com/pratim4dasude) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -5488,7 +5488,7 @@ Editing at Scale", many thanks to their contribution!
This implementation of Flux Kontext allows users to pass multiple reference images. Each image is encoded separately, and the resulting latent vectors are concatenated.
-As explained in Section 3 of [the paper](https://arxiv.org/pdf/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
+As explained in Section 3 of [the paper](https://huggingface.co/papers/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
## Example Usage
@@ -5527,3 +5527,106 @@ images = pipe(
).images
images[0].save("pizzeria.png")
```
+
+# Flux Fill ControlNet Pipeline
+
+This implementation of Flux Fill + ControlNet Inpaint combines the fill-style masked editing of FLUX.1-Fill-dev with full ControlNet conditioning. The base image is processed through the Fill model while the ControlNet receives the corresponding conditioning input (depth, canny, pose, etc.), and both outputs are fused during denoising to guide structure and composition.
+
+While FLUX.1-Fill-dev is designed for mask-based edits, it was not originally trained to operate jointly with ControlNet. In practice, this combined setup works well for structured inpainting tasks, though results may vary depending on the conditioning strength and the alignment between the mask and the control input.
+
+## Example Usage
+
+
+```python
+import torch
+from diffusers import (
+ FluxControlNetModel,
+ FluxPriorReduxPipeline,
+)
+from diffusers.utils import load_image
+
+# Community pipeline class defined in examples/community/pipline_flux_fill_controlnet_Inpaint.py
+from pipline_flux_fill_controlnet_Inpaint import FluxControlNetFillInpaintPipeline
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.bfloat16
+
+# Models
+base_model = "black-forest-labs/FLUX.1-Fill-dev"
+controlnet_model = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0"
+prior_model = "black-forest-labs/FLUX.1-Redux-dev"
+
+# Load ControlNet
+controlnet = FluxControlNetModel.from_pretrained(
+ controlnet_model,
+ torch_dtype=dtype,
+)
+
+# Load Fill + ControlNet Pipeline
+fill_pipe = FluxControlNetFillInpaintPipeline.from_pretrained(
+ base_model,
+ controlnet=controlnet,
+ torch_dtype=dtype,
+).to(device)
+
+# OPTIONAL FP8
+# fill_pipe.transformer.enable_layerwise_casting(
+# storage_dtype=torch.float8_e4m3fn,
+# compute_dtype=torch.bfloat16
+# )
+
+# OPTIONAL Prior Redux
+#pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
+# prior_model,
+# torch_dtype=dtype,
+#).to(device)
+
+# Inputs
+
+# combined_image = load_image("person_input.png")
+
+
+# 1. Prior conditioning
+#prior_out = pipe_prior_redux(
+# image=cloth_image,
+# prompt=cloth_prompt,
+#)
+
+# 2. Fill Inpaint with ControlNet
+
+# canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6).
+
+img = load_image(r"imgs/background.jpg")
+mask = load_image(r"imgs/mask.png")
+
+control_image_depth = load_image(r"imgs/dog_depth _2.png")
+
+result = fill_pipe(
+ prompt="a dog on a bench",
+ image=img,
+ mask_image=mask,
+
+ control_image=control_image_depth,
+    control_mode=[2],  # depth (index 2) for the Union ControlNet
+ control_guidance_start=0.0,
+ control_guidance_end=0.8,
+ controlnet_conditioning_scale=0.9,
+
+ height=1024,
+ width=1024,
+
+ strength=1.0,
+ guidance_scale=50.0,
+ num_inference_steps=60,
+ max_sequence_length=512,
+
+# **prior_out,
+)
+
+# result.images[0].save("flux_fill_controlnet_inpaint.png")
+
+from datetime import datetime
+timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+result.images[0].save(f"flux_fill_controlnet_inpaint_depth{timestamp}.jpg")
+```
+
diff --git a/examples/community/pipline_flux_fill_controlnet_Inpaint.py b/examples/community/pipline_flux_fill_controlnet_Inpaint.py
new file mode 100644
index 000000000000..6b1c204df03b
--- /dev/null
+++ b/examples/community/pipline_flux_fill_controlnet_Inpaint.py
@@ -0,0 +1,1319 @@
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import (
+ CLIPTextModel,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from diffusers.models.autoencoders import AutoencoderKL
+from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+from diffusers.models.transformers import FluxTransformer2DModel
+from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__)
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+    >>> from pipline_flux_fill_controlnet_Inpaint import FluxControlNetFillInpaintPipeline
+ >>> from diffusers.models import FluxControlNetModel
+ >>> from diffusers.utils import load_image
+
+ >>> controlnet = FluxControlNetModel.from_pretrained(
+ ... "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.float16
+ ... )
+    >>> pipe = FluxControlNetFillInpaintPipeline.from_pretrained(
+    ...     "black-forest-labs/FLUX.1-Fill-dev", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> control_image = load_image(
+ ... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg"
+ ... )
+ >>> init_image = load_image(
+ ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ ... )
+ >>> mask_image = load_image(
+ ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ ... )
+
+ >>> prompt = "A girl holding a sign that says InstantX"
+ >>> image = pipe(
+ ... prompt,
+ ... image=init_image,
+ ... mask_image=mask_image,
+ ... control_image=control_image,
+ ... control_guidance_start=0.2,
+ ... control_guidance_end=0.8,
+ ... controlnet_conditioning_scale=0.7,
+ ... strength=0.7,
+ ... num_inference_steps=28,
+ ... guidance_scale=3.5,
+ ... ).images[0]
+ >>> image.save("flux_controlnet_inpaint.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def retrieve_latents_fill(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxControlNetFillInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+ r"""
+    The Flux Fill pipeline for inpainting with ControlNet conditioning.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "control_image", "mask", "masked_image_latents"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ controlnet: Union[
+ FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
+ ],
+ ):
+ super().__init__()
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = FluxMultiControlNetModel(controlnet)
+
+ self.register_modules(
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ controlnet=controlnet,
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ else:
+ noise = latents.to(device)
+ latents = noise
+
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+ image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents, noise, image_latents, latent_image_ids
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == 16:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
+
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+ mask = self._pack_latents(
+ mask.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ def prepare_mask_latents_fill(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # 1. calculate the height and width of the latents
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ # 2. encode the masked image
+ if masked_image.shape[1] == num_channels_latents:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents_fill(self.vae.encode(masked_image), generator=generator)
+
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ # 3. duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ batch_size = batch_size * num_images_per_prompt
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ # 4. pack the masked_image_latents
+ # batch_size, num_channels_latents, height, width -> batch_size, height//2 * width//2 , num_channels_latents*4
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+        # 5. resize the mask to the latents shape as we concatenate the mask to the latents
+ mask = mask[:, 0, :, :] # batch_size, 8 * height, 8 * width (mask has not been 8x compressed)
+ mask = mask.view(
+ batch_size, height, self.vae_scale_factor, width, self.vae_scale_factor
+ ) # batch_size, height, 8, width, 8
+ mask = mask.permute(0, 2, 4, 1, 3) # batch_size, 8, 8, height, width
+ mask = mask.reshape(
+ batch_size, self.vae_scale_factor * self.vae_scale_factor, height, width
+ ) # batch_size, 8*8, height, width
+
+ # 6. pack the mask:
+ # batch_size, 64, height, width -> batch_size, height//2 * width//2 , 64*2*2
+ mask = self._pack_latents(
+ mask,
+ batch_size,
+ self.vae_scale_factor * self.vae_scale_factor,
+ height,
+ width,
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.6,
+ padding_mask_crop: Optional[int] = None,
+ sigmas: Optional[List[float]] = None,
+ num_inference_steps: int = 28,
+ guidance_scale: float = 7.0,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ control_mode: Optional[Union[int, List[int]]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+ The image(s) to inpaint.
+ mask_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+ The mask image(s) to use for inpainting. White pixels in the mask will be repainted, while black pixels
+ will be preserved.
+ masked_image_latents (`torch.FloatTensor`, *optional*):
+ Pre-generated masked image latents.
+ control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+ The ControlNet input condition. Image to control the generation.
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ strength (`float`, *optional*, defaults to 0.6):
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1.
+ padding_mask_crop (`int`, *optional*):
+ The size of the padding to use when cropping the mask.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ control_mode (`int` or `List[int]`, *optional*):
+ The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original transformer.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to
+ make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ Additional keyword arguments to be passed to the joint attention mechanism.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising step during the inference.
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ The maximum length of the sequence to be generated.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ global_height = height
+ global_width = width
+
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type=output_type,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ padding_mask_crop=padding_mask_crop,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ dtype = self.transformer.dtype
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Preprocess mask and image
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(
+ mask_image, global_width, global_height, pad=padding_mask_crop
+ )
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 5. Prepare control image
+ # num_channels_latents = self.transformer.config.in_channels // 4
+ num_channels_latents = self.vae.config.latent_channels
+
+ if isinstance(self.controlnet, FluxControlNetModel):
+ control_image = self.prepare_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+ height, width = control_image.shape[-2:]
+
+ # xlab controlnet has an input_hint_block and instantx controlnet does not
+ controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
+ if self.controlnet.input_hint_block is None:
+ # vae encode
+ control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # pack
+ height_control_image, width_control_image = control_image.shape[2:]
+ control_image = self._pack_latents(
+ control_image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
+
+ # set control mode
+ if control_mode is not None:
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
+ control_mode = control_mode.reshape([-1, 1])
+
+ elif isinstance(self.controlnet, FluxMultiControlNetModel):
+ control_images = []
+
+ # xlab controlnet has an input_hint_block and instantx controlnet does not
+ controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
+ for i, control_image_ in enumerate(control_image):
+ control_image_ = self.prepare_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+ height, width = control_image_.shape[-2:]
+
+ if self.controlnet.nets[0].input_hint_block is None:
+ # vae encode
+ control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # pack
+ height_control_image, width_control_image = control_image_.shape[2:]
+ control_image_ = self._pack_latents(
+ control_image_,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+
+ # set control mode
+ control_mode_ = []
+ if isinstance(control_mode, list):
+ for cmode in control_mode:
+ if cmode is None:
+ control_mode_.append(-1)
+ else:
+ control_mode_.append(cmode)
+ control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
+ control_mode = control_mode.reshape([-1, 1])
+
+ # 6. Prepare timesteps
+
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * (
+ int(global_width) // self.vae_scale_factor // 2
+ )
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
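+ # Truncate the schedule according to `strength` so denoising starts from a partially noised version of the input image.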
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 7. Prepare latent variables
+
+ latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ global_height,
+ global_width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 8. Prepare mask latents
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+ if masked_image_latents is None:
+ masked_image = init_image * (mask_condition < 0.5)
+ else:
+ masked_image = masked_image_latents
+
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ global_height,
+ global_width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
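+ # Additionally prepare the mask and masked-image latents in the "fill" layout; they are concatenated
+ # channel-wise to the noisy latents at every denoising step as extra conditioning for the transformer.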
+ mask_image_fill = self.mask_processor.preprocess(mask_image, height=height, width=width)
+ masked_image_fill = init_image * (1 - mask_image_fill)
+ masked_image_fill = masked_image_fill.to(dtype=self.vae.dtype, device=device)
+ mask_fill_latents, masked_image_fill_latents = self.prepare_mask_latents_fill(
+ mask_image_fill,
+ masked_image_fill,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
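+ # Build a per-step keep schedule so each ControlNet is only applied within its
+ # [control_guidance_start, control_guidance_end] fraction of the denoising steps.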
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
+
+ # 9. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ # predict the noise residual
+ if isinstance(self.controlnet, FluxMultiControlNetModel):
+ use_guidance = self.controlnet.nets[0].config.guidance_embeds
+ else:
+ use_guidance = self.controlnet.config.guidance_embeds
+ if use_guidance:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
+ hidden_states=latents,
+ controlnet_cond=control_image,
+ controlnet_mode=control_mode,
+ conditioning_scale=cond_scale,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )
+
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
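+ # Append the fill-layout mask and masked-image latents to the noisy latents channel-wise before the transformer call.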
+ masked_image_latents_fill = torch.cat((masked_image_fill_latents, mask_fill_latents), dim=-1)
+ latent_model_input = torch.cat([latents, masked_image_latents_fill], dim=2)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ controlnet_block_samples=controlnet_block_samples,
+ controlnet_single_block_samples=controlnet_single_block_samples,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ controlnet_blocks_repeat=controlnet_blocks_repeat,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # For inpainting, we need to apply the mask and add the masked image latents
+ init_latents_proper = image_latents
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep]), noise
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ control_image = callback_outputs.pop("control_image", control_image)
+ mask = callback_outputs.pop("mask", mask)
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # Post-processing
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, global_height, global_width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py
index bca67e3959d8..3bc780cfcf7a 100644
--- a/examples/community/regional_prompting_stable_diffusion.py
+++ b/examples/community/regional_prompting_stable_diffusion.py
@@ -490,7 +490,7 @@ def hook_forwards(root_module: torch.nn.Module):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -841,7 +841,7 @@ def stable_diffusion_call(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
@@ -872,7 +872,7 @@ def stable_diffusion_call(
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
@@ -1062,7 +1062,7 @@ def stable_diffusion_call(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
@@ -1668,7 +1668,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md
index 242f018b654b..42edbb122136 100644
--- a/examples/dreambooth/README_flux.md
+++ b/examples/dreambooth/README_flux.md
@@ -268,12 +268,11 @@ provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_f
**important**
> [!NOTE]
-> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
+> To make sure you can successfully run the latest version of the Kontext example script, we highly recommend installing from source.
> To do this, execute the following steps in a new virtual environment:
> ```
> git clone https://github.com/huggingface/diffusers
> cd diffusers
-> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
> pip install -e .
> ```
diff --git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py
index fbb7c1d9e706..a939a06cbd46 100644
--- a/scripts/convert_sana_video_to_diffusers.py
+++ b/scripts/convert_sana_video_to_diffusers.py
@@ -80,6 +80,8 @@ def main(args):
# scheduler
flow_shift = 8.0
+ if args.task == "i2v":
+ assert args.scheduler_type == "flow-euler", "Scheduler type must be flow-euler for i2v task."
# model config
layer_num = 20
@@ -312,6 +314,7 @@ def main(args):
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
help="Scheduler type to use.",
)
+ parser.add_argument("--task", default="t2v", type=str, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v.")
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py
index 39a364b07d78..30324f143e38 100644
--- a/scripts/convert_wan_to_diffusers.py
+++ b/scripts/convert_wan_to_diffusers.py
@@ -6,13 +6,26 @@
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
-from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
+from transformers import (
+ AutoProcessor,
+ AutoTokenizer,
+ CLIPImageProcessor,
+ CLIPVisionModel,
+ CLIPVisionModelWithProjection,
+ UMT5EncoderModel,
+ Wav2Vec2ForCTC,
+ Wav2Vec2Processor,
+)
from diffusers import (
AutoencoderKLWan,
UniPCMultistepScheduler,
+ WanAnimatePipeline,
+ WanAnimateTransformer3DModel,
WanImageToVideoPipeline,
WanPipeline,
+ WanS2VTransformer3DModel,
+ WanSpeechToVideoPipeline,
WanTransformer3DModel,
WanVACEPipeline,
WanVACETransformer3DModel,
@@ -105,8 +118,254 @@
"after_proj": "proj_out",
}
+ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
+ "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+ "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+ "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+ "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+ "time_projection.1": "condition_embedder.time_proj",
+ "head.modulation": "scale_shift_table",
+ "head.head": "proj_out",
+ "modulation": "scale_shift_table",
+ "ffn.0": "ffn.net.0.proj",
+ "ffn.2": "ffn.net.2",
+ # Hack to swap the layer names
+ # The original model calls the norms in the following order: norm1, norm3, norm2
+ # We convert it to: norm1, norm2, norm3
+ "norm2": "norm__placeholder",
+ "norm3": "norm2",
+ "norm__placeholder": "norm3",
+ "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+ "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+ "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+ "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+ # Add attention component mappings
+ "self_attn.q": "attn1.to_q",
+ "self_attn.k": "attn1.to_k",
+ "self_attn.v": "attn1.to_v",
+ "self_attn.o": "attn1.to_out.0",
+ "self_attn.norm_q": "attn1.norm_q",
+ "self_attn.norm_k": "attn1.norm_k",
+ "cross_attn.q": "attn2.to_q",
+ "cross_attn.k": "attn2.to_k",
+ "cross_attn.v": "attn2.to_v",
+ "cross_attn.o": "attn2.to_out.0",
+ "cross_attn.norm_q": "attn2.norm_q",
+ "cross_attn.norm_k": "attn2.norm_k",
+ "cross_attn.k_img": "attn2.to_k_img",
+ "cross_attn.v_img": "attn2.to_v_img",
+ "cross_attn.norm_k_img": "attn2.norm_k_img",
+ # After cross_attn -> attn2 rename, we need to rename the img keys
+ "attn2.to_k_img": "attn2.add_k_proj",
+ "attn2.to_v_img": "attn2.add_v_proj",
+ "attn2.norm_k_img": "attn2.norm_added_k",
+ # Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
+ # Motion encoder mappings
+ # The name mapping is complicated for the convolutional part so we handle that in its own function
+ "motion_encoder.enc.fc": "motion_encoder.motion_network",
+ "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
+ # Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
+ "face_encoder.conv1_local.conv": "face_encoder.conv1_local",
+ "face_encoder.conv2.conv": "face_encoder.conv2",
+ "face_encoder.conv3.conv": "face_encoder.conv3",
+ # Face adapter mappings are handled in a separate function
+}
+
+S2V_TRANSFORMER_KEYS_RENAME_DICT = {
+ "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+ "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+ "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+ "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+ "time_projection.1": "condition_embedder.time_proj",
+ "head.modulation": "scale_shift_table",
+ "head.head": "proj_out",
+ "modulation": "scale_shift_table",
+ "ffn.0": "ffn.net.0.proj",
+ "ffn.2": "ffn.net.2",
+ # Hack to swap the layer names
+ # The original model calls the norms in the following order: norm1, norm3, norm2
+ # We convert it to: norm1, norm2, norm3
+ "norm2": "norm__placeholder",
+ "norm3": "norm2",
+ "norm__placeholder": "norm3",
+ # Add attention component mappings
+ "self_attn.q": "attn1.to_q",
+ "self_attn.k": "attn1.to_k",
+ "self_attn.v": "attn1.to_v",
+ "self_attn.o": "attn1.to_out.0",
+ "self_attn.norm_q": "attn1.norm_q",
+ "self_attn.norm_k": "attn1.norm_k",
+ "cross_attn.q": "attn2.to_q",
+ "cross_attn.k": "attn2.to_k",
+ "cross_attn.v": "attn2.to_v",
+ "cross_attn.o": "attn2.to_out.0",
+ "cross_attn.norm_q": "attn2.norm_q",
+ "cross_attn.norm_k": "attn2.norm_k",
+ "attn2.to_k_img": "attn2.add_k_proj",
+ "attn2.to_v_img": "attn2.add_v_proj",
+ "attn2.norm_k_img": "attn2.norm_added_k",
+ # S2V-specific audio component mappings
+ "casual_audio_encoder.encoder.conv2.conv": "condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv",
+ "casual_audio_encoder.encoder.conv3.conv": "condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv",
+ "casual_audio_encoder.weights": "condition_embedder.causal_audio_encoder.weighted_avg.weights",
+ # Pose condition encoder mappings
+ "cond_encoder.weight": "condition_embedder.pose_embedder.weight",
+ "cond_encoder.bias": "condition_embedder.pose_embedder.bias",
+ "trainable_cond_mask": "trainable_condition_mask",
+ "patch_embedding": "motion_in.patch_embedding",
+ # Audio injector attention mappings - convert original q/k/v/o format to diffusers format
+ **{
+ f"audio_injector.injector.{i}.{src}": f"audio_injector.injector.{i}.{dst}"
+ for i in range(12)
+ for src, dst in [("q", "to_q"), ("k", "to_k"), ("v", "to_v"), ("o", "to_out.0")]
+ },
+}
+
+
+# TODO: Verify this and simplify if possible.
+def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
+ """
+ Convert all motion encoder weights for the Wan Animate model.
+
+ In the original model:
+ - All Linear layers in fc use EqualLinear
+ - All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately)
+ - Blur kernels are stored as buffers in Sequential modules
+ - ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]
+
+ Conversion strategy:
+ 1. Drop .kernel buffers (blur kernels)
+ 2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
+ """
+ # Skip if not a weight, bias, or kernel
+ if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
+ return
+
+ # Handle Blur kernel buffers from original implementation.
+ # After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel
+ # Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys
+ if ".kernel" in key and "motion_encoder" in key:
+ # Remove unexpected blur kernel buffers to avoid strict load errors
+ state_dict.pop(key, None)
+ return
+
+ # Rename Sequential indices to named components in ConvLayer and ResBlock
+ if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
+ parts = key.split(".")
+
+ # Find the sequential index (digit) after convs or after conv1/conv2/skip
+ # Examples:
+ # - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
+ # - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
+ # - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
+ # - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight
+ # - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
+ # - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias
+ # - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
+ # - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias)
+ # - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
+ # - enc.net_app.convs.8 -> conv_out (final conv layer)
+
+ convs_idx = parts.index("convs") if "convs" in parts else -1
+ if convs_idx >= 0 and len(parts) - convs_idx >= 2:
+ bias = False
+ # The nn.Sequential index will always follow convs
+ sequential_idx = int(parts[convs_idx + 1])
+ if sequential_idx == 0:
+ if key.endswith(".weight"):
+ new_key = "motion_encoder.conv_in.weight"
+ elif key.endswith(".bias"):
+ new_key = "motion_encoder.conv_in.act_fn.bias"
+ bias = True
+ elif sequential_idx == final_conv_idx:
+ if key.endswith(".weight"):
+ new_key = "motion_encoder.conv_out.weight"
+ else:
+ # Intermediate .convs. layers, which get mapped to .res_blocks.
+ prefix = "motion_encoder.res_blocks."
+
+ layer_name = parts[convs_idx + 2]
+ if layer_name == "skip":
+ layer_name = "conv_skip"
+
+ if key.endswith(".weight"):
+ param_name = "weight"
+ elif key.endswith(".bias"):
+ param_name = "act_fn.bias"
+ bias = True
+
+ suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
+ suffix = ".".join(suffix_parts)
+ new_key = prefix + suffix
+
+ param = state_dict.pop(key)
+ if bias:
+ param = param.squeeze()
+ state_dict[new_key] = param
+ return
+ return
+ return
+
+
+def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
+ """
+ Convert face adapter weights for the Animate model.
+
+ The original model uses a fused KV projection, but the diffusers model uses separate K and V projections.
+ """
+ # Skip if not a weight or bias
+ if ".weight" not in key and ".bias" not in key:
+ return
+
+ prefix = "face_adapter."
+ if ".fuser_blocks." in key:
+ parts = key.split(".")
+
+ module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
+ if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
+ block_idx = parts[module_list_idx + 1]
+ layer_name = parts[module_list_idx + 2]
+ param_name = parts[module_list_idx + 3]
+
+ if layer_name == "linear1_kv":
+ layer_name_k = "to_k"
+ layer_name_v = "to_v"
+
+ suffix_k = ".".join([block_idx, layer_name_k, param_name])
+ suffix_v = ".".join([block_idx, layer_name_v, param_name])
+ new_key_k = prefix + suffix_k
+ new_key_v = prefix + suffix_v
+
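+ # Split the fused KV projection in half along the output dimension into separate K and V tensors.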
+ kv_proj = state_dict.pop(key)
+ k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
+ state_dict[new_key_k] = k_proj
+ state_dict[new_key_v] = v_proj
+ return
+ else:
+ if layer_name == "q_norm":
+ new_layer_name = "norm_q"
+ elif layer_name == "k_norm":
+ new_layer_name = "norm_k"
+ elif layer_name == "linear1_q":
+ new_layer_name = "to_q"
+ elif layer_name == "linear2":
+ new_layer_name = "to_out"
+
+ suffix_parts = [block_idx, new_layer_name, param_name]
+ suffix = ".".join(suffix_parts)
+ new_key = prefix + suffix
+ state_dict[new_key] = state_dict.pop(key)
+ return
+ return
+
+
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "motion_encoder": convert_animate_motion_encoder_weights,
+ "face_adapter": convert_animate_face_adapter_weights,
+}
+S2V_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
@@ -364,6 +623,67 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan2.2-Animate-14B":
+ config = {
+ "model_id": "Wan-AI/Wan2.2-Animate-14B",
+ "diffusers_config": {
+ "image_dim": 1280,
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": (1, 2, 2),
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "rope_max_seq_len": 1024,
+ "pos_embed_seq_len": None,
+ "motion_encoder_size": 512, # Start of Wan Animate-specific configs
+ "motion_style_dim": 512,
+ "motion_dim": 20,
+ "motion_encoder_dim": 512,
+ "face_encoder_hidden_dim": 1024,
+ "face_encoder_num_heads": 4,
+ "inject_face_latents_blocks": 5,
+ },
+ }
+ RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan2.2-S2V-14B":
+ config = {
+ "model_id": "Wan-AI/Wan2.2-S2V-14B",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": (1, 2, 2),
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "audio_dim": 1024,
+ "audio_inject_layers": [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
+ "enable_adain": True,
+ "adain_mode": "attn_norm",
+ "pose_dim": 16,
+ "enable_framepack": True,
+ "framepack_drop_mode": "padd",
+ "add_last_motion": True,
+ "zero_timestep": True,
+ },
+ }
+ RENAME_DICT = S2V_TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = S2V_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
@@ -380,10 +700,14 @@ def convert_transformer(model_type: str, stage: str = None):
original_state_dict = load_sharded_safetensors(model_dir)
with init_empty_weights():
- if "VACE" not in model_type:
- transformer = WanTransformer3DModel.from_config(diffusers_config)
- else:
+ if "S2V" in model_type:
+ transformer = WanS2VTransformer3DModel.from_config(diffusers_config)
+ elif "Animate" in model_type:
+ transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
+ elif "VACE" in model_type:
transformer = WanVACETransformer3DModel.from_config(diffusers_config)
+ else:
+ transformer = WanTransformer3DModel.from_config(diffusers_config)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -397,7 +721,12 @@ def convert_transformer(model_type: str, stage: str = None):
continue
handler_fn_inplace(key, original_state_dict)
+ # Load state dict into the meta model, which will materialize the tensors
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
+
+ # Move to CPU to ensure all tensors are materialized
+ transformer = transformer.to("cpu")
+
return transformer
@@ -926,7 +1255,7 @@ def get_args():
if __name__ == "__main__":
args = get_args()
- if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
+ if "Wan2.2" in args.model_type and not any(tag in args.model_type for tag in ("TI2V", "Animate", "S2V")):
transformer = convert_transformer(args.model_type, stage="high_noise_model")
transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
else:
@@ -942,7 +1271,7 @@ def get_args():
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
if "FLF2V" in args.model_type:
flow_shift = 16.0
- elif "TI2V" in args.model_type:
+ elif any(tag in args.model_type for tag in ("TI2V", "Animate", "S2V")):
flow_shift = 5.0
else:
flow_shift = 3.0
@@ -954,6 +1283,8 @@ def get_args():
if args.dtype != "none":
dtype = DTYPE_MAPPING[args.dtype]
transformer.to(dtype)
+ if transformer_2 is not None:
+ transformer_2.to(dtype)
if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
pipe = WanImageToVideoPipeline(
@@ -1016,6 +1347,36 @@ def get_args():
vae=vae,
scheduler=scheduler,
)
+ elif "Animate" in args.model_type:
+ image_encoder = CLIPVisionModel.from_pretrained(
+ "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
+ )
+ image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+ pipe = WanAnimatePipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ image_processor=image_processor,
+ )
+ elif "S2V" in args.model_type:
+ audio_encoder = Wav2Vec2ForCTC.from_pretrained(
+ "Wan-AI/Wan2.2-S2V-14B", subfolder="wav2vec2-large-xlsr-53-english"
+ )
+ audio_processor = Wav2Vec2Processor.from_pretrained(
+ "Wan-AI/Wan2.2-S2V-14B", subfolder="wav2vec2-large-xlsr-53-english"
+ )
+ pipe = WanSpeechToVideoPipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ audio_encoder=audio_encoder,
+ audio_processor=audio_processor,
+ )
else:
pipe = WanPipeline(
transformer=transformer,
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index a5040bd28394..965f991b5783 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -268,6 +268,8 @@
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel",
+ "WanAnimateTransformer3DModel",
+ "WanS2VTransformer3DModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"attention_backend",
@@ -544,11 +546,13 @@
"QwenImagePipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
+ "SanaImageToVideoPipeline",
"SanaPAGPipeline",
"SanaPipeline",
"SanaSprintImg2ImgPipeline",
"SanaSprintPipeline",
"SanaVideoPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -636,8 +640,10 @@
"VisualClozeGenerationPipeline",
"VisualClozePipeline",
"VQDiffusionPipeline",
+ "WanAnimatePipeline",
"WanImageToVideoPipeline",
"WanPipeline",
+ "WanSpeechToVideoPipeline",
"WanVACEPipeline",
"WanVideoToVideoPipeline",
"WuerstchenCombinedPipeline",
@@ -977,6 +983,8 @@
UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel,
+ WanAnimateTransformer3DModel,
+ WanS2VTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
attention_backend,
@@ -1224,6 +1232,7 @@
QwenImagePipeline,
ReduxImageEncoder,
SanaControlNetPipeline,
+ SanaImageToVideoPipeline,
SanaPAGPipeline,
SanaPipeline,
SanaSprintImg2ImgPipeline,
@@ -1315,8 +1324,10 @@
VisualClozeGenerationPipeline,
VisualClozePipeline,
VQDiffusionPipeline,
+ WanAnimatePipeline,
WanImageToVideoPipeline,
WanPipeline,
+ WanSpeechToVideoPipeline,
WanVACEPipeline,
WanVideoToVideoPipeline,
WuerstchenCombinedPipeline,
diff --git a/src/diffusers/audio_processor.py b/src/diffusers/audio_processor.py
new file mode 100644
index 000000000000..491aacf530aa
--- /dev/null
+++ b/src/diffusers/audio_processor.py
@@ -0,0 +1,71 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Union
+
+import numpy as np
+import torch
+
+
+PipelineAudioInput = Union[
+ np.ndarray,
+ torch.Tensor,
+ List[np.ndarray],
+ List[torch.Tensor],
+]
+
+
+def is_valid_audio(audio) -> bool:
+ r"""
+ Checks if the input is a valid audio.
+
+ A valid audio can be:
+ - A 2D or 3D `np.ndarray` or `torch.Tensor` (e.g. a mono or multi-channel waveform or spectrogram).
+
+ Args:
+ audio (`Union[np.ndarray, torch.Tensor]`):
+ The audio to validate. It can be a NumPy array or a torch tensor.
+
+ Returns:
+ `bool`:
+ `True` if the input is a valid audio, `False` otherwise.
+ """
+ return isinstance(audio, (np.ndarray, torch.Tensor)) and audio.ndim in (2, 3)
+
+
+def is_valid_audio_audiolist(audios):
+ r"""
+ Checks if the input is a valid audio or list of audios.
+
+ The input can be one of the following formats:
+ - A 4D tensor or numpy array (batch of audios).
+ - A valid single audio: `np.ndarray` or `torch.Tensor`.
+ - A list of valid audios.
+
+ Args:
+ audios (`Union[np.ndarray, torch.Tensor, List]`):
+ The audio(s) to check. Can be a batch of audios (4D tensor/array), a single audio, or a list of valid
+ audios.
+
+ Returns:
+ `bool`:
+ `True` if the input is valid, `False` otherwise.
+ """
+ if isinstance(audios, (np.ndarray, torch.Tensor)) and audios.ndim == 4:
+ return True
+ elif is_valid_audio(audios):
+ return True
+ elif isinstance(audios, list):
+ return all(is_valid_audio(audio) for audio in audios)
+ return False
diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py
index 52cb0ce34980..6c328328fc3b 100644
--- a/src/diffusers/guiders/guider_utils.py
+++ b/src/diffusers/guiders/guider_utils.py
@@ -373,7 +373,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index 067d876ffcd8..bb88f728a401 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -409,7 +409,7 @@ def _resize_and_fill(
src_w = width if ratio < src_ratio else image.width * height // image.height
src_h = height if ratio >= src_ratio else image.height * width // image.width
- resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = Image.new("RGB", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
@@ -437,10 +437,11 @@ def _resize_and_crop(
image: PIL.Image.Image,
width: int,
height: int,
+ resize_type: str = "fit_within",
+ crop_type: str = "paste_center",
) -> PIL.Image.Image:
r"""
- Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
- the image within the dimensions, cropping the excess.
+ Resize and crop the image using different strategies.
Args:
image (`PIL.Image.Image`):
@@ -449,28 +450,55 @@ def _resize_and_crop(
The width to resize the image to.
height (`int`):
The height to resize the image to.
+ resize_type (`str`, *optional*, defaults to `"fit_within"`):
+ How to resize the image. Options:
+ - "fit_within": Resize to fit within dimensions, maintaining aspect ratio (default)
+ - "min_dimension": Resize so smaller dimension becomes min(width, height)
+ crop_type (`str`, *optional*, defaults to `"paste_center"`):
+ How to handle the final cropping/positioning. Options:
+ - "paste_center": Paste resized image on centered canvas, pad with black (default)
+ - "center_crop": Center crop to exact dimensions, pad with black if needed
Returns:
`PIL.Image.Image`:
The resized and cropped image.
"""
- ratio = width / height
- src_ratio = image.width / image.height
- src_w = width if ratio > src_ratio else image.width * height // image.height
- src_h = height if ratio <= src_ratio else image.height * width // image.width
+ if resize_type == "fit_within":
+ # Resize to fit within dimensions
+ ratio = width / height
+ src_ratio = image.width / image.height
- resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
- res = Image.new("RGB", (width, height))
- res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
- return res
+ src_w = width if ratio > src_ratio else image.width * height // image.height
+ src_h = height if ratio <= src_ratio else image.height * width // image.width
+
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
+ elif resize_type == "min_dimension":
+ # Resize so the smaller dimension becomes min(width, height)
+ from torchvision.transforms import Resize
+
+ resized = Resize(min(height, width))(image)
+ else:
+ raise ValueError(f"Unknown resize_type: {resize_type}")
+
+ if crop_type == "paste_center":
+ # Paste on canvas, center position
+ res = Image.new("RGB", (width, height), color=0) # Black background
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+ return res
+ elif crop_type == "center_crop":
+ from torchvision.transforms import CenterCrop
+
+ return CenterCrop((height, width))(resized)
+ else:
+ raise ValueError(f"Unknown crop_type: {crop_type}")
def resize(
self,
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
height: int,
width: int,
- resize_mode: str = "default", # "default", "fill", "crop"
+ resize_mode: str = "default", # "default", "fill", "crop", "resize_min_center_crop"
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Resize image.
@@ -483,13 +511,16 @@ def resize(
width (`int`):
The width to resize to.
resize_mode (`str`, *optional*, defaults to `default`):
- The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
- within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
- will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
- then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
- the image to fit within the specified width and height, maintaining the aspect ratio, and then center
- the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
- supported for PIL image input.
+ The resize mode to use, can be one of `default`, `fill`, `crop`, or `resize_min_center_crop`. If
+ `default`, will resize the image to fit within the specified width and height, and it may not
+ maintain the original aspect ratio. If `fill`, will resize the image to fit within the specified
+ width and height, maintaining the aspect ratio, and then center the image within the dimensions,
+ filling empty with data from image. If `crop`, will resize the image to fit within the specified width
+ and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the
+ excess. If `resize_min_center_crop`, will resize the image so that the smaller dimension becomes
+ min(width, height), then center crop to exact target dimensions (matches Wan2.2-S2V preprocessing).
+ Note that resize_mode `fill`, `crop`, and `resize_min_center_crop` are only supported for PIL image
+ input.
Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
@@ -508,6 +539,10 @@ def resize(
image = self._resize_and_fill(image, width, height)
elif resize_mode == "crop":
image = self._resize_and_crop(image, width, height)
+ elif resize_mode == "resize_min_center_crop":
+ image = self._resize_and_crop(
+ image, width, height, resize_type="min_dimension", crop_type="center_crop"
+ )
else:
raise ValueError(f"resize_mode {resize_mode} is not supported")
@@ -615,7 +650,7 @@ def preprocess(
image: PipelineImageInput,
height: Optional[int] = None,
width: Optional[int] = None,
- resize_mode: str = "default", # "default", "fill", "crop"
+ resize_mode: str = "default", # "default", "fill", "crop", "resize_min_center_crop"
crops_coords: Optional[Tuple[int, int, int, int]] = None,
) -> torch.Tensor:
"""
@@ -631,13 +666,15 @@ def preprocess(
width (`int`, *optional*):
The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
resize_mode (`str`, *optional*, defaults to `default`):
- The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
- the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
- resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
- center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
- image to fit within the specified width and height, maintaining the aspect ratio, and then center the
- image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
- supported for PIL image input.
+ The resize mode, can be one of `default`, `fill`, `crop`, or `resize_min_center_crop`. If `default`,
+ will resize the image to fit within the specified width and height, and it may not maintain the
+ original aspect ratio. If `fill`, will resize the image to fit within the specified width and height,
+ maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data
+ from image. If `crop`, will resize the image to fit within the specified width and height, maintaining
+ the aspect ratio, and then center the image within the dimensions, cropping the excess. If
+ `resize_min_center_crop`, will resize the image so that the smaller dimension becomes min(width,
+ height), then center crop to exact target dimensions (matches Wan2.2-S2V preprocessing). Note that
+ resize_mode `fill`, `crop`, and `resize_min_center_crop` are only supported for PIL image input.
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
The crop coordinates for each image in the batch. If `None`, will not crop the image.
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index e97ab8bd1d2a..89d7debb34b1 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -108,6 +108,8 @@
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
+ _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
+ _import_structure["transformers.transformer_wan_s2v"] = ["WanS2VTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
@@ -214,6 +216,8 @@
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
+ WanAnimateTransformer3DModel,
+ WanS2VTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
)
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 92a4a6a59936..8504504981a3 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -16,6 +16,7 @@
import functools
import inspect
import math
+from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
@@ -42,7 +43,7 @@
is_xformers_available,
is_xformers_version,
)
-from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
if TYPE_CHECKING:
@@ -82,24 +83,11 @@
flash_attn_3_func = None
flash_attn_3_varlen_func = None
-
if _CAN_USE_AITER_ATTN:
from aiter import flash_attn_func as aiter_flash_attn_func
else:
aiter_flash_attn_func = None
-if DIFFUSERS_ENABLE_HUB_KERNELS:
- if not is_kernels_available():
- raise ImportError(
- "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
- )
- from ..utils.kernels_utils import _get_fa3_from_hub
-
- flash_attn_interface_hub = _get_fa3_from_hub()
- flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
-else:
- flash_attn_3_func_hub = None
-
if _CAN_USE_SAGE_ATTN:
from sageattention import (
sageattn,
@@ -261,6 +249,25 @@ def _is_context_parallel_available(
return supports_context_parallel
+@dataclass
+class _HubKernelConfig:
+ """Configuration for downloading and using a hub-based attention kernel."""
+
+ repo_id: str
+ function_attr: str
+ revision: Optional[str] = None
+ kernel_fn: Optional[Callable] = None
+
+
+# Registry for hub-based attention kernels
+_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
+ # TODO: temporary revision for now. Remove when merged upstream into `main`.
+ AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
+ repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+ )
+}
+
+
@contextlib.contextmanager
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
"""
@@ -415,13 +422,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
# TODO: add support Hub variant of FA3 varlen later
elif backend in [AttentionBackendName._FLASH_3_HUB]:
- if not DIFFUSERS_ENABLE_HUB_KERNELS:
- raise RuntimeError(
- f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
- )
if not is_kernels_available():
raise RuntimeError(
- f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+ f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
elif backend == AttentionBackendName.AITER:
@@ -571,6 +574,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
return q_idx >= kv_idx
+# ===== Helpers for downloading kernels =====
+def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
+ if backend not in _HUB_KERNELS_REGISTRY:
+ return
+ config = _HUB_KERNELS_REGISTRY[backend]
+
+ if config.kernel_fn is not None:
+ return
+
+ try:
+ from kernels import get_kernel
+
+ kernel_module = get_kernel(config.repo_id, revision=config.revision)
+ kernel_func = getattr(kernel_module, config.function_attr)
+
+ # Cache the downloaded kernel function in the config object
+ config.kernel_fn = kernel_func
+
+ except Exception as e:
+ logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
+ raise
+
+
# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
@@ -1418,7 +1444,8 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
- out = flash_attn_3_func_hub(
+ func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
+ out = func(
q=query,
k=key,
v=value,
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
index 0aadbad9f4de..618801dfb605 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
@@ -16,7 +16,7 @@
# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
# For more information about the Wan VAE, please refer to:
# - GitHub: https://github.com/Wan-Video/Wan2.1
-# - arXiv: https://arxiv.org/abs/2503.20314
+# - Paper: https://huggingface.co/papers/2503.20314
from typing import List, Optional, Tuple, Union
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index 5b4b74543ae3..b0b2960aaf18 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -971,7 +971,7 @@ def __init__(
base_dim: int = 96,
decoder_base_dim: Optional[int] = None,
z_dim: int = 16,
- dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
+ dim_mult: List[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index e4a8f30e721f..f06822c741ca 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -595,7 +595,11 @@ def set_attention_backend(self, backend: str) -> None:
attention as backend.
"""
from .attention import AttentionModuleMixin
- from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
+ from .attention_dispatch import (
+ AttentionBackendName,
+ _check_attention_backend_requirements,
+ _maybe_download_kernel_for_backend,
+ )
# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
@@ -606,8 +610,10 @@ def set_attention_backend(self, backend: str) -> None:
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
backend = AttentionBackendName(backend)
_check_attention_backend_requirements(backend)
+ _maybe_download_kernel_for_backend(backend)
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 66daf56e23b2..2286c2c120b3 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -42,4 +42,6 @@
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
+ from .transformer_wan_animate import WanAnimateTransformer3DModel
+ from .transformer_wan_s2v import WanS2VTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
diff --git a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py
index 9b2664b9cb26..ccbc83ffca03 100644
--- a/src/diffusers/models/transformers/transformer_prx.py
+++ b/src/diffusers/models/transformers/transformer_prx.py
@@ -275,7 +275,12 @@ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
- scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+
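+ # MPS and NPU do not support float64, so fall back to float32 for the RoPE frequency computation.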
+ is_mps = pos.device.type == "mps"
+ is_npu = pos.device.type == "npu"
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+
+ scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py
index 424d9ff9d360..a4f90342631a 100644
--- a/src/diffusers/models/transformers/transformer_sana_video.py
+++ b/src/diffusers/models/transformers/transformer_sana_video.py
@@ -188,6 +188,11 @@ def __init__(
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
+
+ self.t_dim = t_dim
+ self.h_dim = h_dim
+ self.w_dim = w_dim
+
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_cos = []
@@ -213,11 +218,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
- split_sizes = [
- self.attention_head_dim - 2 * (self.attention_head_dim // 3),
- self.attention_head_dim // 3,
- self.attention_head_dim // 3,
- ]
+ split_sizes = [self.t_dim, self.h_dim, self.w_dim]
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
@@ -236,7 +237,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return freqs_cos, freqs_sin
-# Copied from diffusers.models.transformers.sana_transformer.SanaModulatedNorm
class SanaModulatedNorm(nn.Module):
def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
super().__init__()
@@ -246,7 +246,7 @@ def forward(
self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
) -> torch.Tensor:
hidden_states = self.norm(hidden_states)
- shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
+ shift, scale = (scale_shift_table[None, None] + temb[:, :, None].to(scale_shift_table.device)).unbind(dim=2)
hidden_states = hidden_states * (1 + scale) + shift
return hidden_states
@@ -422,8 +422,8 @@ def forward(
# 1. Modulation
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
- ).chunk(6, dim=1)
+ self.scale_shift_table[None, None] + timestep.reshape(batch_size, timestep.shape[1], 6, -1)
+ ).unbind(dim=2)
# 2. Self Attention
norm_hidden_states = self.norm1(hidden_states)
@@ -634,13 +634,16 @@ def forward(
if guidance is not None:
timestep, embedded_timestep = self.time_embed(
- timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
+ timestep.flatten(), guidance=guidance, hidden_dtype=hidden_states.dtype
)
else:
timestep, embedded_timestep = self.time_embed(
- timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
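+ # Restore the per-sample timestep dimension after embedding the flattened timesteps.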
+ timestep = timestep.view(batch_size, -1, timestep.size(-1))
+ embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
+
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
diff --git a/src/diffusers/models/transformers/transformer_wan_animate.py b/src/diffusers/models/transformers/transformer_wan_animate.py
new file mode 100644
index 000000000000..6a47a67385a3
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_wan_animate.py
@@ -0,0 +1,1298 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import FP32LayerNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES = {
+ "4": 512,
+ "8": 512,
+ "16": 512,
+ "32": 512,
+ "64": 256,
+ "128": 128,
+ "256": 64,
+ "512": 32,
+ "1024": 16,
+}
+
+
+# Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+ # encoder_hidden_states is only passed for cross-attention
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ if attn.fused_projections:
+ if attn.cross_attention_dim_head is None:
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+ else:
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
+ query = attn.to_q(hidden_states)
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ return query, key, value
+
+
+# Copied from diffusers.models.transformers.transformer_wan._get_added_kv_projections
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+ if attn.fused_projections:
+ key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+ else:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+ return key_img, value_img
+
+
+class FusedLeakyReLU(nn.Module):
+ """
+ Fused LeakyRelu with scale factor and channel-wise bias.
+ """
+
+ def __init__(self, negative_slope: float = 0.2, scale: float = 2**0.5, bias_channels: Optional[int] = None):
+ super().__init__()
+ self.negative_slope = negative_slope
+ self.scale = scale
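+        # NOTE: the default scale of sqrt(2) approximately restores unit variance after the leaky ReLU, matching
+        # the fused leaky-ReLU op this layer is modeled on.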
+ self.channels = bias_channels
+
+ if self.channels is not None:
+ self.bias = nn.Parameter(
+ torch.zeros(
+ self.channels,
+ )
+ )
+ else:
+ self.bias = None
+
+ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
+ if self.bias is not None:
+            # Expand self.bias to have all singleton dims except at channel_dim
+ expanded_shape = [1] * x.ndim
+ expanded_shape[channel_dim] = self.bias.shape[0]
+ bias = self.bias.reshape(*expanded_shape)
+ x = x + bias
+ return F.leaky_relu(x, self.negative_slope) * self.scale
+
+
+class MotionConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ bias: bool = True,
+ blur_kernel: Optional[Tuple[int, ...]] = None,
+ blur_upsample_factor: int = 1,
+ use_activation: bool = True,
+ ):
+ super().__init__()
+ self.use_activation = use_activation
+ self.in_channels = in_channels
+
+ # Handle blurring (applying a FIR filter with the given kernel) if available
+ self.blur = False
+ if blur_kernel is not None:
+ p = (len(blur_kernel) - stride) + (kernel_size - 1)
+ self.blur_padding = ((p + 1) // 2, p // 2)
+
+ kernel = torch.tensor(blur_kernel)
+ # Convert kernel to 2D if necessary
+ if kernel.ndim == 1:
+ kernel = kernel[None, :] * kernel[:, None]
+ # Normalize kernel
+ kernel = kernel / kernel.sum()
+ if blur_upsample_factor > 1:
+ kernel = kernel * (blur_upsample_factor**2)
+ self.register_buffer("blur_kernel", kernel, persistent=False)
+ self.blur = True
+
+ # Main Conv2d parameters (with scale factor)
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
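+        # The weights are drawn from a unit normal and rescaled by 1 / sqrt(fan_in) at call time, i.e. the
+        # "equalized learning rate" scheme used by StyleGAN-style encoders.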
+
+ self.stride = stride
+ self.padding = padding
+
+ # If using an activation function, the bias will be fused into the activation
+ if bias and not self.use_activation:
+ self.bias = nn.Parameter(torch.zeros(out_channels))
+ else:
+ self.bias = None
+
+ if self.use_activation:
+ self.act_fn = FusedLeakyReLU(bias_channels=out_channels)
+ else:
+ self.act_fn = None
+
+ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
+ # Apply blur if using
+ if self.blur:
+ # NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates
+ # set to 1, which should be equivalent to a 2D convolution
+ expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1)
+ x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels)
+
+ # Main Conv2D with scaling
+ x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
+
+ # Activation with fused bias, if using
+ if self.use_activation:
+ x = self.act_fn(x, channel_dim=channel_dim)
+ return x
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
+ f" kernel_size={self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
+ )
+
+
+class MotionLinear(nn.Module):
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ bias: bool = True,
+ use_activation: bool = False,
+ ):
+ super().__init__()
+ self.use_activation = use_activation
+
+ # Linear weight with scale factor
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
+ self.scale = 1 / math.sqrt(in_dim)
+
+ # If an activation is present, the bias will be fused to it
+ if bias and not self.use_activation:
+ self.bias = nn.Parameter(torch.zeros(out_dim))
+ else:
+ self.bias = None
+
+ if self.use_activation:
+ self.act_fn = FusedLeakyReLU(bias_channels=out_dim)
+ else:
+ self.act_fn = None
+
+ def forward(self, input: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
+ out = F.linear(input, self.weight * self.scale, bias=self.bias)
+ if self.use_activation:
+ out = self.act_fn(out, channel_dim=channel_dim)
+ return out
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}(in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]},"
+ f" bias={self.bias is not None})"
+ )
+
+
+class MotionEncoderResBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ kernel_size_skip: int = 1,
+ blur_kernel: Tuple[int, ...] = (1, 3, 3, 1),
+ downsample_factor: int = 2,
+ ):
+ super().__init__()
+ self.downsample_factor = downsample_factor
+
+ # 3 x 3 Conv + fused leaky ReLU
+ self.conv1 = MotionConv2d(
+ in_channels,
+ in_channels,
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ use_activation=True,
+ )
+
+ # 3 x 3 Conv that downsamples 2x + fused leaky ReLU
+ self.conv2 = MotionConv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=self.downsample_factor,
+ padding=0,
+ blur_kernel=blur_kernel,
+ use_activation=True,
+ )
+
+ # 1 x 1 Conv that downsamples 2x in skip connection
+ self.conv_skip = MotionConv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size_skip,
+ stride=self.downsample_factor,
+ padding=0,
+ bias=False,
+ blur_kernel=blur_kernel,
+ use_activation=False,
+ )
+
+ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
+ x_out = self.conv1(x, channel_dim)
+ x_out = self.conv2(x_out, channel_dim)
+
+ x_skip = self.conv_skip(x, channel_dim)
+
+ x_out = (x_out + x_skip) / math.sqrt(2)
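+        # Dividing by sqrt(2) keeps the sum of the main and skip branches at roughly the same scale as each
+        # branch, as in StyleGAN-style residual encoder blocks.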
+ return x_out
+
+
+class WanAnimateMotionEncoder(nn.Module):
+ def __init__(
+ self,
+ size: int = 512,
+ style_dim: int = 512,
+ motion_dim: int = 20,
+ out_dim: int = 512,
+ motion_blocks: int = 5,
+ channels: Optional[Dict[str, int]] = None,
+ ):
+ super().__init__()
+ self.size = size
+
+ # Appearance encoder: conv layers
+ if channels is None:
+ channels = WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES
+
+ self.conv_in = MotionConv2d(3, channels[str(size)], 1, use_activation=True)
+
+ self.res_blocks = nn.ModuleList()
+ in_channels = channels[str(size)]
+ log_size = int(math.log(size, 2))
+ for i in range(log_size, 2, -1):
+ out_channels = channels[str(2 ** (i - 1))]
+ self.res_blocks.append(MotionEncoderResBlock(in_channels, out_channels))
+ in_channels = out_channels
+
+ self.conv_out = MotionConv2d(in_channels, style_dim, 4, padding=0, bias=False, use_activation=False)
+
+ # Motion encoder: linear layers
+        # NOTE: unusually, there are no activations between the linear layers here; this follows the original
+        # implementation.
+ linears = [MotionLinear(style_dim, style_dim) for _ in range(motion_blocks - 1)]
+ linears.append(MotionLinear(style_dim, motion_dim))
+ self.motion_network = nn.ModuleList(linears)
+
+ self.motion_synthesis_weight = nn.Parameter(torch.randn(out_dim, motion_dim))
+
+ def forward(self, face_image: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
+ if (face_image.shape[-2] != self.size) or (face_image.shape[-1] != self.size):
+ raise ValueError(
+ f"Face pixel values has resolution ({face_image.shape[-1]}, {face_image.shape[-2]}) but is expected"
+ f" to have resolution ({self.size}, {self.size})"
+ )
+
+ # Appearance encoding through convs
+ face_image = self.conv_in(face_image, channel_dim)
+ for block in self.res_blocks:
+ face_image = block(face_image, channel_dim)
+ face_image = self.conv_out(face_image, channel_dim)
+ motion_feat = face_image.squeeze(-1).squeeze(-1)
+
+ # Motion feature extraction
+ for linear_layer in self.motion_network:
+ motion_feat = linear_layer(motion_feat, channel_dim=channel_dim)
+
+ # Motion synthesis via Linear Motion Decomposition
+ weight = self.motion_synthesis_weight + 1e-8
+ # Upcast the QR orthogonalization operation to FP32
+ original_motion_dtype = motion_feat.dtype
+ motion_feat = motion_feat.to(torch.float32)
+ weight = weight.to(torch.float32)
+
+ Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device)
+
+ motion_feat_diag = torch.diag_embed(motion_feat) # Alpha, diagonal matrix
+ motion_decomposition = torch.matmul(motion_feat_diag, Q.T)
+ motion_vec = torch.sum(motion_decomposition, dim=1)
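+        # Equivalent to motion_feat @ Q.T: each entry of the motion code scales one orthonormal column of Q, so the
+        # output is a linear combination of learned, orthogonalized motion directions.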
+
+ motion_vec = motion_vec.to(dtype=original_motion_dtype)
+
+ return motion_vec
+
+
+class WanAnimateFaceEncoder(nn.Module):
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ hidden_dim: int = 1024,
+ num_heads: int = 4,
+ kernel_size: int = 3,
+ eps: float = 1e-6,
+ pad_mode: str = "replicate",
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.time_causal_padding = (kernel_size - 1, 0)
+ self.pad_mode = pad_mode
+
+ self.act = nn.SiLU()
+
+ self.conv1_local = nn.Conv1d(in_dim, hidden_dim * num_heads, kernel_size=kernel_size, stride=1)
+ self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride=2)
+ self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride=2)
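+        # The two stride-2 causal convs downsample the face frames roughly 4x in time, which appears to match the
+        # temporal compression of the Wan VAE so that face features align frame-wise with the video latents.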
+
+ self.norm1 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False)
+ self.norm2 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False)
+ self.norm3 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False)
+
+ self.out_proj = nn.Linear(hidden_dim, out_dim)
+
+ self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, out_dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ batch_size = x.shape[0]
+
+ # Reshape to channels-first to apply causal Conv1d over frame dim
+ x = x.permute(0, 2, 1)
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ x = self.conv1_local(x) # [B, C, T_padded] --> [B, N * C, T]
+ x = x.unflatten(1, (self.num_heads, -1)).flatten(0, 1) # [B, N * C, T] --> [B * N, C, T]
+ # Reshape back to channels-last to apply LayerNorm over channel dim
+ x = x.permute(0, 2, 1)
+ x = self.norm1(x)
+ x = self.act(x)
+
+ x = x.permute(0, 2, 1)
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ x = self.conv2(x)
+ x = x.permute(0, 2, 1)
+ x = self.norm2(x)
+ x = self.act(x)
+
+ x = x.permute(0, 2, 1)
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ x = self.conv3(x)
+ x = x.permute(0, 2, 1)
+ x = self.norm3(x)
+ x = self.act(x)
+
+ x = self.out_proj(x)
+ x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # [B * N, T, C_out] --> [B, T, N, C_out]
+
+ padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1).to(device=x.device)
+ x = torch.cat([x, padding], dim=-2) # [B, T, N, C_out] --> [B, T, N + 1, C_out]
+
+ return x
+
+
+class WanAnimateFaceBlockAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ f"{self.__class__.__name__} requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or"
+ f" higher."
+ )
+
+ def __call__(
+ self,
+ attn: "WanAnimateFaceBlockCrossAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # encoder_hidden_states corresponds to the motion vec
+ # attention_mask corresponds to the motion mask (if any)
+ hidden_states = attn.pre_norm_q(hidden_states)
+ encoder_hidden_states = attn.pre_norm_kv(encoder_hidden_states)
+
+ # B --> batch_size, T --> reduced inference segment len, N --> face_encoder_num_heads + 1, C --> attn.dim
+ B, T, N, C = encoder_hidden_states.shape
+
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1)) # [B, S, H * D] --> [B, S, H, D]
+ key = key.view(B, T, N, attn.heads, -1) # [B, T, N, H * D_kv] --> [B, T, N, H, D_kv]
+ value = value.view(B, T, N, attn.heads, -1)
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ # NOTE: the below line (which follows the official code) means that in practice, the number of frames T in
+ # encoder_hidden_states (the motion vector after applying the face encoder) must evenly divide the
+ # post-patchify sequence length S of the transformer hidden_states. Is it possible to remove this dependency?
+ query = query.unflatten(1, (T, -1)).flatten(0, 1) # [B, S, H, D] --> [B * T, S / T, H, D]
+ key = key.flatten(0, 1) # [B, T, N, H, D_kv] --> [B * T, N, H, D_kv]
+ value = value.flatten(0, 1)
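+        # After this regrouping, the S / T video tokens of each latent frame attend only to that frame's N face
+        # tokens, which is what makes this cross-attention temporally aligned.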
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+ hidden_states = hidden_states.unflatten(0, (B, T)).flatten(1, 2)
+
+ hidden_states = attn.to_out(hidden_states)
+
+ if attention_mask is not None:
+ # NOTE: attention_mask is assumed to be a multiplicative mask
+ attention_mask = attention_mask.flatten(start_dim=1)
+ hidden_states = hidden_states * attention_mask
+
+ return hidden_states
+
+
+class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin):
+ """
+ Temporally-aligned cross attention with the face motion signal in the Wan Animate Face Blocks.
+ """
+
+ _default_processor_cls = WanAnimateFaceBlockAttnProcessor
+ _available_processors = [WanAnimateFaceBlockAttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ eps: float = 1e-6,
+ cross_attention_dim_head: Optional[int] = None,
+ processor=None,
+ ):
+ super().__init__()
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.cross_attention_head_dim = cross_attention_dim_head
+ self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+ # 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector).
+ # NOTE: this is not used in "vanilla" WanAttention
+ self.pre_norm_q = nn.LayerNorm(dim, eps, elementwise_affine=False)
+ self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False)
+
+ # 2. QKV and Output Projections
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True)
+
+ # 3. QK Norm
+ # NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=True)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=True)
+
+ # 4. Set attention processor
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask)
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor
+class WanAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+ )
+
+ def __call__(
+ self,
+ attn: "WanAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None:
+ # 512 is the context length of the text encoder, hardcoded for now
+ image_context_length = encoder_hidden_states.shape[1] - 512
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+ encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
+
+ # I2V task
+ hidden_states_img = None
+ if encoder_hidden_states_img is not None:
+ key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
+ key_img = attn.norm_added_k(key_img)
+
+ key_img = key_img.unflatten(2, (attn.heads, -1))
+ value_img = value_img.unflatten(2, (attn.heads, -1))
+
+ hidden_states_img = dispatch_attention_fn(
+ query,
+ key_img,
+ value_img,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states_img = hidden_states_img.flatten(2, 3)
+ hidden_states_img = hidden_states_img.type_as(query)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ if hidden_states_img is not None:
+ hidden_states = hidden_states + hidden_states_img
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanAttention
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = WanAttnProcessor
+ _available_processors = [WanAttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ eps: float = 1e-5,
+ dropout: float = 0.0,
+ added_kv_proj_dim: Optional[int] = None,
+ cross_attention_dim_head: Optional[int] = None,
+ processor=None,
+ is_cross_attention=None,
+ ):
+ super().__init__()
+
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.cross_attention_dim_head = cross_attention_dim_head
+ self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_out = torch.nn.ModuleList(
+ [
+ torch.nn.Linear(self.inner_dim, dim, bias=True),
+ torch.nn.Dropout(dropout),
+ ]
+ )
+ self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+ self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+ self.add_k_proj = self.add_v_proj = None
+ if added_kv_proj_dim is not None:
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+ self.is_cross_attention = cross_attention_dim_head is not None
+
+ self.set_processor(processor)
+
+ def fuse_projections(self):
+ if getattr(self, "fused_projections", False):
+ return
+
+ if self.cross_attention_dim_head is None:
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+ self.to_qkv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ if self.added_kv_proj_dim is not None:
+ concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+ concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_added_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ def unfuse_projections(self):
+ if not getattr(self, "fused_projections", False):
+ return
+
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+ if hasattr(self, "to_added_kv"):
+ delattr(self, "to_added_kv")
+
+ self.fused_projections = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding
+class WanImageEmbedding(torch.nn.Module):
+ def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
+ super().__init__()
+
+ self.norm1 = FP32LayerNorm(in_features)
+ self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+ self.norm2 = FP32LayerNorm(out_features)
+ if pos_embed_seq_len is not None:
+ self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+ else:
+ self.pos_embed = None
+
+ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+ if self.pos_embed is not None:
+ batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+ encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+ encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
+ hidden_states = self.norm1(encoder_hidden_states_image)
+ hidden_states = self.ff(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
+class WanTimeTextImageEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ time_freq_dim: int,
+ time_proj_dim: int,
+ text_embed_dim: int,
+ image_embed_dim: Optional[int] = None,
+ pos_embed_seq_len: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ self.act_fn = nn.SiLU()
+ self.time_proj = nn.Linear(dim, time_proj_dim)
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+ self.image_embedder = None
+ if image_embed_dim is not None:
+ self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ timestep_seq_len: Optional[int] = None,
+ ):
+ timestep = self.timesteps_proj(timestep)
+ if timestep_seq_len is not None:
+ timestep = timestep.unflatten(0, (-1, timestep_seq_len))
+
+ time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ timestep = timestep.to(time_embedder_dtype)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+ timestep_proj = self.time_proj(self.act_fn(temb))
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+ return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed
+class WanRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.max_seq_len = max_seq_len
+
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
+ t_dim = attention_head_dim - h_dim - w_dim
+
+ self.t_dim = t_dim
+ self.h_dim = h_dim
+ self.w_dim = w_dim
+
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ freqs_cos = []
+ freqs_sin = []
+
+ for dim in [t_dim, h_dim, w_dim]:
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
+ dim,
+ max_seq_len,
+ theta,
+ use_real=True,
+ repeat_interleave_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ freqs_cos.append(freq_cos)
+ freqs_sin.append(freq_sin)
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+ split_sizes = [self.t_dim, self.h_dim, self.w_dim]
+
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+ return freqs_cos, freqs_sin
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanTransformerBlock
+class WanTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ qk_norm: str = "rms_norm_across_heads",
+ cross_attn_norm: bool = False,
+ eps: float = 1e-6,
+ added_kv_proj_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # 1. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.attn1 = WanAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ cross_attention_dim_head=None,
+ processor=WanAttnProcessor(),
+ )
+
+ # 2. Cross-attention
+ self.attn2 = WanAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ added_kv_proj_dim=added_kv_proj_dim,
+ cross_attention_dim_head=dim // num_heads,
+ processor=WanAttnProcessor(),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+ # 3. Feed-forward
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ ) -> torch.Tensor:
+ if temb.ndim == 4:
+ # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table.unsqueeze(0) + temb.float()
+ ).chunk(6, dim=2)
+ # batch_size, seq_len, 1, inner_dim
+ shift_msa = shift_msa.squeeze(2)
+ scale_msa = scale_msa.squeeze(2)
+ gate_msa = gate_msa.squeeze(2)
+ c_shift_msa = c_shift_msa.squeeze(2)
+ c_scale_msa = c_scale_msa.squeeze(2)
+ c_gate_msa = c_gate_msa.squeeze(2)
+ else:
+ # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=1)
+
+ # 1. Self-attention
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+ attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+ # 2. Cross-attention
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
+ hidden_states = hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ hidden_states
+ )
+ ff_output = self.ffn(norm_hidden_states)
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+ return hidden_states
+
+
+class WanAnimateTransformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
+ r"""
+ A Transformer model for video-like data used in the WanAnimate model.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+        num_attention_heads (`int`, defaults to `40`):
+            The number of attention heads.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        in_channels (`int`, *optional*, defaults to `36`):
+            The number of channels in the input. If `None`, it is inferred from `latent_channels` as
+            `2 * latent_channels + 4`.
+        latent_channels (`int`, *optional*, defaults to `16`):
+            The number of latent channels of the Wan VAE. If `None`, it is inferred from `in_channels` as
+            `(in_channels - 4) // 2`.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        text_dim (`int`, defaults to `4096`):
+            Input dimension for text embeddings.
+        freq_dim (`int`, defaults to `256`):
+            Dimension for sinusoidal time embeddings.
+        ffn_dim (`int`, defaults to `13824`):
+            Intermediate dimension in feed-forward network.
+        num_layers (`int`, defaults to `40`):
+            The number of layers of transformer blocks to use.
+        cross_attn_norm (`bool`, defaults to `True`):
+            Enable cross-attention normalization.
+        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
+            The normalization to apply to the query and key projections.
+        eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        image_dim (`int`, *optional*, defaults to `1280`):
+            The number of channels to use for the image embedding. If `None`, no projection is used.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+ _no_split_modules = ["WanTransformerBlock", "MotionEncoderResBlock"]
+ _keep_in_fp32_modules = [
+ "time_embedder",
+ "scale_shift_table",
+ "norm1",
+ "norm2",
+ "norm3",
+ "motion_synthesis_weight",
+ ]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+ _repeated_blocks = ["WanTransformerBlock"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (1, 2, 2),
+ num_attention_heads: int = 40,
+ attention_head_dim: int = 128,
+ in_channels: Optional[int] = 36,
+ latent_channels: Optional[int] = 16,
+ out_channels: Optional[int] = 16,
+ text_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 13824,
+ num_layers: int = 40,
+ cross_attn_norm: bool = True,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ eps: float = 1e-6,
+ image_dim: Optional[int] = 1280,
+ added_kv_proj_dim: Optional[int] = None,
+ rope_max_seq_len: int = 1024,
+ pos_embed_seq_len: Optional[int] = None,
+ motion_encoder_channel_sizes: Optional[Dict[str, int]] = None, # Start of Wan Animate-specific args
+ motion_encoder_size: int = 512,
+ motion_style_dim: int = 512,
+ motion_dim: int = 20,
+ motion_encoder_dim: int = 512,
+ face_encoder_hidden_dim: int = 1024,
+ face_encoder_num_heads: int = 4,
+ inject_face_latents_blocks: int = 5,
+ motion_encoder_batch_size: int = 8,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ # Allow either only in_channels or only latent_channels to be set for convenience
+ if in_channels is None and latent_channels is not None:
+ in_channels = 2 * latent_channels + 4
+ elif in_channels is not None and latent_channels is None:
+ latent_channels = (in_channels - 4) // 2
+ elif in_channels is not None and latent_channels is not None:
+ # TODO: should this always be true?
+ assert in_channels == 2 * latent_channels + 4, "in_channels should be 2 * latent_channels + 4"
+ else:
+ raise ValueError("At least one of `in_channels` and `latent_channels` must be supplied.")
+ out_channels = out_channels or latent_channels
+
+ # 1. Patch & position embedding
+ self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+ self.pose_patch_embedding = nn.Conv3d(latent_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Condition embeddings
+ self.condition_embedder = WanTimeTextImageEmbedding(
+ dim=inner_dim,
+ time_freq_dim=freq_dim,
+ time_proj_dim=inner_dim * 6,
+ text_embed_dim=text_dim,
+ image_embed_dim=image_dim,
+ pos_embed_seq_len=pos_embed_seq_len,
+ )
+
+ # Motion encoder
+ self.motion_encoder = WanAnimateMotionEncoder(
+ size=motion_encoder_size,
+ style_dim=motion_style_dim,
+ motion_dim=motion_dim,
+ out_dim=motion_encoder_dim,
+ channels=motion_encoder_channel_sizes,
+ )
+
+ # Face encoder
+ self.face_encoder = WanAnimateFaceEncoder(
+ in_dim=motion_encoder_dim,
+ out_dim=inner_dim,
+ hidden_dim=face_encoder_hidden_dim,
+ num_heads=face_encoder_num_heads,
+ )
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ WanTransformerBlock(
+ dim=inner_dim,
+ ffn_dim=ffn_dim,
+ num_heads=num_attention_heads,
+ qk_norm=qk_norm,
+ cross_attn_norm=cross_attn_norm,
+ eps=eps,
+ added_kv_proj_dim=added_kv_proj_dim,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.face_adapter = nn.ModuleList(
+ [
+ WanAnimateFaceBlockCrossAttention(
+ dim=inner_dim,
+ heads=num_attention_heads,
+ dim_head=inner_dim // num_attention_heads,
+ eps=eps,
+ cross_attention_dim_head=inner_dim // num_attention_heads,
+ processor=WanAnimateFaceBlockAttnProcessor(),
+ )
+ for _ in range(num_layers // inject_face_latents_blocks)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ pose_hidden_states: Optional[torch.Tensor] = None,
+ face_pixel_values: Optional[torch.Tensor] = None,
+ motion_encode_batch_size: Optional[int] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ """
+ Forward pass of Wan2.2-Animate transformer model.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(B, 2C + 4, T + 1, H, W)`):
+ Input noisy video latents of shape `(B, 2C + 4, T + 1, H, W)`, where B is the batch size, C is the
+ number of latent channels (16 for Wan VAE), T is the number of latent frames in an inference segment, H
+ is the latent height, and W is the latent width.
+ timestep: (`torch.LongTensor`):
+ The current timestep in the denoising loop.
+ encoder_hidden_states (`torch.Tensor`):
+ Text embeddings from the text encoder (umT5 for Wan Animate).
+ encoder_hidden_states_image (`torch.Tensor`):
+ CLIP visual features of the reference (character) image.
+ pose_hidden_states (`torch.Tensor` of shape `(B, C, T, H, W)`):
+                Pose video latents conditioning the generation on the driving pose sequence. Must have one fewer
+                latent frame than `hidden_states` along dim 2 (see the shape check in the forward pass).
+ face_pixel_values (`torch.Tensor` of shape `(B, C', S, H', W')`):
+ Face video in pixel space (not latent space). Typically C' = 3 and H' and W' are the height/width of
+ the face video in pixels. Here S is the inference segment length, usually set to 77.
+ motion_encode_batch_size (`int`, *optional*):
+ The batch size for batched encoding of the face video via the motion encoder. Will default to
+ `self.config.motion_encoder_batch_size` if not set.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return the output as a dict or tuple.
+ """
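+        # Shape sketch for the default config (16 latent channels, patch_size=(1, 2, 2), motion_encoder_size=512):
+        #   hidden_states:      (B, 36, T + 1, H, W)   (36 = 2 * latent_channels + 4)
+        #   pose_hidden_states: (B, 16, T, H, W)
+        #   face_pixel_values:  (B, 3, S, 512, 512)    with S the pixel-space segment length (e.g. 77)
+        #   returned sample:    (B, 16, T + 1, H, W)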
+
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # Check that shapes match up
+ if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]:
+ raise ValueError(
+ f"pose_hidden_states frame dim (dim 2) is {pose_hidden_states.shape[2]} but must be one less than the"
+ f" hidden_states's corresponding frame dim: {hidden_states.shape[2]}"
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ # 1. Rotary position embedding
+ rotary_emb = self.rope(hidden_states)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embedding(hidden_states)
+ pose_hidden_states = self.pose_patch_embedding(pose_hidden_states)
+ # Add pose embeddings to hidden states
+ hidden_states[:, :, 1:] = hidden_states[:, :, 1:] + pose_hidden_states
+ # Calling contiguous() here is important so that we don't recompile when performing regional compilation
+ hidden_states = hidden_states.flatten(2).transpose(1, 2).contiguous()
+
+ # 3. Condition embeddings (time, text, image)
+ # Wan Animate is based on Wan 2.1 and thus uses Wan 2.1's timestep logic
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=None
+ )
+
+ # batch_size, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ # 4. Get motion features from the face video
+ # Motion vector computation from face pixel values
+ batch_size, channels, num_face_frames, height, width = face_pixel_values.shape
+ # Rearrange from (B, C, T, H, W) to (B*T, C, H, W)
+ face_pixel_values = face_pixel_values.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width)
+
+ # Extract motion features using motion encoder
+ # Perform batched motion encoder inference to allow trading off inference speed for memory usage
+ motion_encode_batch_size = motion_encode_batch_size or self.config.motion_encoder_batch_size
+ face_batches = torch.split(face_pixel_values, motion_encode_batch_size)
+ motion_vec_batches = []
+ for face_batch in face_batches:
+ motion_vec_batch = self.motion_encoder(face_batch)
+ motion_vec_batches.append(motion_vec_batch)
+ motion_vec = torch.cat(motion_vec_batches)
+ motion_vec = motion_vec.view(batch_size, num_face_frames, -1)
+
+ # Now get face features from the motion vector
+ motion_vec = self.face_encoder(motion_vec)
+
+ # Add padding at the beginning (prepend zeros)
+ pad_face = torch.zeros_like(motion_vec[:, :1])
+ motion_vec = torch.cat([pad_face, motion_vec], dim=1)
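+        # The all-zero leading entry pairs the extra conditioning latent frame in `hidden_states` (T + 1 latent
+        # frames vs. T face feature frames) with a blank face feature instead of a real face frame.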
+
+ # 5. Transformer blocks with face adapter integration
+ for block_idx, block in enumerate(self.blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+ )
+ else:
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+
+            # Face adapter integration: applied at every `inject_face_latents_blocks`-th block
+            # (0, 5, 10, ... with the default of 5)
+ if block_idx % self.config.inject_face_latents_blocks == 0:
+ face_adapter_block_idx = block_idx // self.config.inject_face_latents_blocks
+ face_adapter_output = self.face_adapter[face_adapter_block_idx](hidden_states, motion_vec)
+ # In case the face adapter and main transformer blocks are on different devices, which can happen when
+ # using model parallelism
+ face_adapter_output = face_adapter_output.to(device=hidden_states.device)
+ hidden_states = face_adapter_output + hidden_states
+
+ # 6. Output norm, projection & unpatchify
+ # batch_size, inner_dim
+ shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
+
+ hidden_states_original_dtype = hidden_states.dtype
+ hidden_states = self.norm_out(hidden_states.float())
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up
+ # on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+ hidden_states = (hidden_states * (1 + scale) + shift).to(dtype=hidden_states_original_dtype)
+
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py
new file mode 100644
index 000000000000..8440f4346968
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_wan_s2v.py
@@ -0,0 +1,1188 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin, get_parameter_dtype
+from ..normalization import AdaLayerNorm, FP32LayerNorm
+from .transformer_wan import (
+ WanAttention,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+ # encoder_hidden_states is only passed for cross-attention
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ if attn.fused_projections:
+ if attn.cross_attention_dim_head is None:
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+ else:
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
+ query = attn.to_q(hidden_states)
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ return query, key, value
+
+
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+ if attn.fused_projections:
+ key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+ else:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+ return key_img, value_img
+
+
+class WanS2VAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "WanS2VAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+ )
+
+ def __call__(
+ self,
+ attn: "WanAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None:
+ # 512 is the context length of the text encoder, hardcoded for now
+ image_context_length = encoder_hidden_states.shape[1] - 512
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+ encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+ # dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
+ n = hidden_states.size(2)
+ # loop over samples
+ output = []
+ for i in range(hidden_states.size(0)):
+ s = hidden_states.size(1)
+ x_i = torch.view_as_complex(hidden_states[i, :s].to(torch.float64).reshape(s, n, -1, 2))
+ freqs_i = freqs[i, :s]
+ # apply rotary embedding
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
+ x_i = torch.cat([x_i, hidden_states[i, s:]])
+ # append to collection
+ output.append(x_i)
+ return torch.stack(output).type_as(hidden_states)
+
+ query = apply_rotary_emb(query, rotary_emb)
+ key = apply_rotary_emb(key, rotary_emb)
+
+ # I2V task
+ hidden_states_img = None
+ if encoder_hidden_states_img is not None:
+ key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
+ key_img = attn.norm_added_k(key_img)
+
+ key_img = key_img.unflatten(2, (attn.heads, -1))
+ value_img = value_img.unflatten(2, (attn.heads, -1))
+
+ hidden_states_img = dispatch_attention_fn(
+ query,
+ key_img,
+ value_img,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states_img = hidden_states_img.flatten(2, 3)
+ hidden_states_img = hidden_states_img.type_as(query)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ if hidden_states_img is not None:
+ hidden_states = hidden_states + hidden_states_img
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class WanS2VCausalConv1d(nn.Module):
+ def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
+ super().__init__()
+
+ self.pad_mode = pad_mode
+ self.time_causal_padding = (kernel_size - 1, 0) # T
+
+ self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+ def forward(self, x):
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ return self.conv(x)
+
+
+class WanS2VCausalConvLayer(nn.Module):
+ """A layer that combines causal convolution, normalization, and activation in sequence."""
+
+ def __init__(
+ self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", eps=1e-6, **kwargs
+ ):
+ super().__init__()
+
+ self.conv = WanS2VCausalConv1d(chan_in, chan_out, kernel_size, stride, dilation, pad_mode, **kwargs)
+ self.norm = nn.LayerNorm(chan_out, elementwise_affine=False, eps=eps)
+ self.act = nn.SiLU()
+
+ def forward(self, x):
+ x = x.permute(0, 2, 1)
+ x = self.conv(x)
+ x = x.permute(0, 2, 1)
+ x = self.norm(x)
+ x = self.act(x)
+ return x
+
+
+class WanS2VMotionEncoder(nn.Module):
+ def __init__(self, in_dim: int, hidden_dim: int, num_attention_heads: int, need_global: bool = True):
+ super().__init__()
+
+ self.num_attention_heads = num_attention_heads
+ self.need_global = need_global
+ self.conv1_local = WanS2VCausalConv1d(in_dim, hidden_dim // 4 * num_attention_heads, 3, stride=1)
+ if need_global:
+ self.conv1_global = WanS2VCausalConv1d(in_dim, hidden_dim // 4, 3, stride=1)
+ self.conv2 = WanS2VCausalConvLayer(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
+ self.conv3 = WanS2VCausalConvLayer(hidden_dim // 2, hidden_dim, 3, stride=2)
+
+ if need_global:
+ self.final_linear = nn.Linear(hidden_dim, hidden_dim)
+
+ self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6)
+ self.act = nn.SiLU()
+
+ self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
+
+ def forward(self, x):
+ x = x.permute(0, 2, 1)
+ residual = x.clone()
+ batch_size, num_channels, seq_len = x.shape
+ x = self.conv1_local(x)
+ x = x.unflatten(1, (self.num_attention_heads, -1)).permute(0, 1, 3, 2).flatten(0, 1)
+ x = self.norm1(x)
+ x = self.act(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)
+ padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1)
+ x = torch.cat([x, padding], dim=-2)
+ x_local = x.clone()
+
+ if not self.need_global:
+ return x_local
+
+ x = self.conv1_global(residual)
+ x = x.permute(0, 2, 1)
+ x = self.norm1(x)
+ x = self.act(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ x = self.final_linear(x)
+ x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)
+
+ return x, x_local
+
+
+class WeightedAveragelayer(nn.Module):
+ def __init__(self, num_layers):
+ super().__init__()
+ self.weights = torch.nn.Parameter(torch.ones((1, num_layers, 1, 1)) * 0.01)
+ self.act = torch.nn.SiLU()
+
+ def forward(self, features):
+ # features B * num_layers * dim * video_length
+ weights = self.act(self.weights)
+ weights_sum = weights.sum(dim=1, keepdims=True)
+ weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f
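+        # A learned soft average over the stacked per-layer features (here, the audio encoder's hidden layers):
+        # SiLU-gated weights normalized by their sum.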
+
+ return weighted_feat
+
+
+class CausalAudioEncoder(nn.Module):
+ def __init__(self, dim=5120, num_weighted_avg_layers=25, out_dim=2048, num_audio_token=4, need_global=False):
+ super().__init__()
+ self.weighted_avg = WeightedAveragelayer(num_weighted_avg_layers)
+ self.encoder = WanS2VMotionEncoder(
+ in_dim=dim, hidden_dim=out_dim, num_attention_heads=num_audio_token, need_global=need_global
+ )
+
+ def forward(self, features):
+ # features B * num_layers * dim * video_length
+ weighted_feat = self.weighted_avg(features)
+ weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
+ res = self.encoder(weighted_feat) # b f n dim
+
+ return res # b f n dim
+
+
+class AudioInjector(nn.Module):
+ def __init__(
+ self,
+ num_injection_layers,
+ inject_layers,
+ dim=2048,
+ num_heads=32,
+ enable_adain=False,
+ adain_mode="attn_norm",
+ adain_dim=2048,
+ eps=1e-6,
+ added_kv_proj_dim=None,
+ ):
+ super().__init__()
+ self.enable_adain = enable_adain
+ self.adain_mode = adain_mode
+ self.injected_block_id = dict(zip(inject_layers, range(num_injection_layers)))
+
+ # Cross-attention
+ self.injector = nn.ModuleList(
+ [
+ WanAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ added_kv_proj_dim=added_kv_proj_dim,
+ cross_attention_dim_head=dim // num_heads,
+ processor=WanS2VAttnProcessor(),
+ )
+ for _ in range(num_injection_layers)
+ ]
+ )
+
+ self.injector_pre_norm_feat = nn.ModuleList(
+ [nn.LayerNorm(dim, elementwise_affine=False, eps=eps) for _ in range(num_injection_layers)]
+ )
+ self.injector_pre_norm_vec = nn.ModuleList(
+ [nn.LayerNorm(dim, elementwise_affine=False, eps=eps) for _ in range(num_injection_layers)]
+ )
+
+ if enable_adain:
+ self.injector_adain_layers = nn.ModuleList(
+ [
+ AdaLayerNorm(embedding_dim=adain_dim, output_dim=dim * 2, chunk_dim=1)
+ for _ in range(num_injection_layers)
+ ]
+ )
+ if adain_mode != "attn_norm":
+ self.injector_adain_output_layers = nn.ModuleList(
+ [nn.Linear(dim, dim) for _ in range(num_injection_layers)]
+ )
+
+ def forward(
+ self,
+ block_idx,
+ hidden_states,
+ original_sequence_length,
+ merged_audio_emb_num_frames,
+ attn_audio_emb,
+ audio_emb_global,
+ ):
+ audio_attn_id = self.injected_block_id[block_idx]
+
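+        # Regroup the video tokens frame by frame so that each frame cross-attends only to its own
+        # audio tokens; the attention output is then added back to the sequence as a residual.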
+ input_hidden_states = hidden_states[:, :original_sequence_length].clone() # B (F H W) C
+ input_hidden_states = input_hidden_states.unflatten(1, (merged_audio_emb_num_frames, -1)).flatten(0, 1)
+
+ if self.enable_adain and self.adain_mode == "attn_norm":
+ attn_hidden_states = self.injector_adain_layers[audio_attn_id](
+ input_hidden_states, temb=audio_emb_global[:, 0]
+ )
+ else:
+ attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
+
+ residual_out = self.injector[audio_attn_id](attn_hidden_states, attn_audio_emb, None, None)
+ residual_out = residual_out.unflatten(0, (-1, merged_audio_emb_num_frames)).flatten(1, 2)
+ hidden_states[:, :original_sequence_length] = hidden_states[:, :original_sequence_length] + residual_out
+
+ return hidden_states
+
+
+class FramePackMotioner(nn.Module):
+ def __init__(
+ self,
+ inner_dim=1024,
+        num_attention_heads=16,  # Number of heads in the backbone network; only used to shape the rotary embeddings, not this module's own layers
+        zip_frame_buckets=[
+            1,
+            2,
+            16,
+        ],  # Number of motion frames assigned to each bucket, ordered from the nearest to the farthest frames
+        drop_mode="drop",  # "drop" removes the tokens of excluded buckets; any other value ("padd") keeps them but zeroes their latents
+ patch_size=(1, 2, 2),
+ in_channels=16,
+ ):
+ super().__init__()
+ self.inner_dim = inner_dim
+ self.num_attention_heads = num_attention_heads
+ self.in_channels = in_channels
+ if (inner_dim % num_attention_heads) != 0 or (inner_dim // num_attention_heads) % 2 != 0:
+ raise ValueError(
+ "inner_dim must be divisible by num_attention_heads and inner_dim // num_attention_heads must be even"
+ )
+ self.drop_mode = drop_mode
+
+ self.proj = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
+ self.proj_2x = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
+ self.proj_4x = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
+ self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long)
+
+ self.rope = WanS2VRotaryPosEmbed(
+ inner_dim // num_attention_heads,
+ patch_size=patch_size,
+ max_seq_len=1024,
+ num_attention_heads=num_attention_heads,
+ )
+
+ def forward(self, motion_latents, add_last_motion=2):
+ latent_height, latent_width = motion_latents.shape[3], motion_latents.shape[4]
+ padd_latent = torch.zeros(
+ (motion_latents.shape[0], self.in_channels, self.zip_frame_buckets.sum(), latent_height, latent_width),
+ device=motion_latents.device,
+ dtype=motion_latents.dtype,
+ )
+ overlap_frame = min(padd_latent.shape[2], motion_latents.shape[2])
+ if overlap_frame > 0:
+ padd_latent[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:]
+
+ if add_last_motion < 2 and self.drop_mode != "drop":
+ zero_end_frame = self.zip_frame_buckets[: len(self.zip_frame_buckets) - add_last_motion - 1].sum()
+ padd_latent[:, :, -zero_end_frame:] = 0
+
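+        # Split the motion history into three temporal buckets, farthest to nearest, and patchify the
+        # older buckets at coarser spatio-temporal resolution (4x / 2x / 1x), FramePack-style.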
+ clean_latents_4x, clean_latents_2x, clean_latents_post = padd_latent[
+ :, :, -self.zip_frame_buckets.sum() :, :, :
+ ].split(list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2, 1
+
+ # Patchify
+ clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
+ clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2)
+ clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2)
+
+ if add_last_motion < 2 and self.drop_mode == "drop":
+ clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post
+ clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x
+
+ motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
+
+ # RoPE
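+        # Motion tokens are placed at negative temporal offsets so they sit before the clip being
+        # denoised; WanS2VRotaryPosEmbed conjugates the frequencies for negative offsets.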
+ start_time_id = -(self.zip_frame_buckets[:1].sum())
+ end_time_id = start_time_id + self.zip_frame_buckets[0]
+ grid_sizes = (
+ []
+ if add_last_motion < 2 and self.drop_mode == "drop"
+ else [
+ [
+ torch.tensor([start_time_id, 0, 0]).unsqueeze(0),
+ torch.tensor([end_time_id, latent_height // 2, latent_width // 2]).unsqueeze(0),
+ torch.tensor([self.zip_frame_buckets[0], latent_height // 2, latent_width // 2]).unsqueeze(0),
+ ]
+ ]
+ )
+
+ start_time_id = -(self.zip_frame_buckets[:2].sum())
+ end_time_id = start_time_id + self.zip_frame_buckets[1] // 2
+ grid_sizes_2x = (
+ []
+ if add_last_motion < 1 and self.drop_mode == "drop"
+ else [
+ [
+ torch.tensor([start_time_id, 0, 0]).unsqueeze(0),
+ torch.tensor([end_time_id, latent_height // 4, latent_width // 4]).unsqueeze(0),
+ torch.tensor([self.zip_frame_buckets[1], latent_height // 2, latent_width // 2]).unsqueeze(0),
+ ]
+ ]
+ )
+
+ start_time_id = -(self.zip_frame_buckets[:3].sum())
+ end_time_id = start_time_id + self.zip_frame_buckets[2] // 4
+ grid_sizes_4x = [
+ [
+ torch.tensor([start_time_id, 0, 0]).unsqueeze(0),
+ torch.tensor([end_time_id, latent_height // 8, latent_width // 8]).unsqueeze(0),
+ torch.tensor([self.zip_frame_buckets[2], latent_height // 2, latent_width // 2]).unsqueeze(0),
+ ]
+ ]
+
+ grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x
+
+ motion_rope_emb = self.rope(
+ motion_lat.detach().view(
+ motion_lat.shape[0],
+ motion_lat.shape[1],
+ self.num_attention_heads,
+ self.inner_dim // self.num_attention_heads,
+ ),
+ grid_sizes=grid_sizes,
+ )
+
+ return motion_lat, motion_rope_emb
+
+
+class Motioner(nn.Module):
+ def __init__(self, inner_dim, num_attention_heads, patch_size=(1, 2, 2), in_channels=16, rope_max_seq_len=1024):
+ super().__init__()
+ self.inner_dim = inner_dim
+ self.num_attention_heads = num_attention_heads
+
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+ self.rope = WanS2VRotaryPosEmbed(
+ inner_dim // num_attention_heads, patch_size, rope_max_seq_len, num_attention_heads
+ )
+
+ def forward(self, motion_latents):
+ latent_motion_frames = motion_latents.shape[2]
+ mot = self.patch_embedding(motion_latents)
+
+ height, width = mot.shape[3], mot.shape[4]
+ flat_mot = mot.flatten(2).transpose(1, 2).contiguous()
+ motion_grid_sizes = [
+ [
+ torch.tensor([-latent_motion_frames, 0, 0]).unsqueeze(0),
+ torch.tensor([0, height, width]).unsqueeze(0),
+ torch.tensor([latent_motion_frames, height, width]).unsqueeze(0),
+ ]
+ ]
+ motion_rope_emb = self.rope(
+ flat_mot.detach().view(
+ flat_mot.shape[0],
+ flat_mot.shape[1],
+ self.num_attention_heads,
+ self.inner_dim // self.num_attention_heads,
+ ),
+ motion_grid_sizes,
+ )
+
+ return flat_mot, motion_rope_emb
+
+
+class WanTimeTextAudioPoseEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ time_freq_dim: int,
+ time_proj_dim: int,
+ text_embed_dim: int,
+ audio_embed_dim: int,
+ pose_embed_dim: int,
+ patch_size: Tuple[int],
+ enable_adain: bool,
+ num_weighted_avg_layers: int,
+ ):
+ super().__init__()
+
+ self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ self.act_fn = nn.SiLU()
+ self.time_proj = nn.Linear(dim, time_proj_dim)
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+ self.causal_audio_encoder = CausalAudioEncoder(
+ dim=audio_embed_dim,
+ num_weighted_avg_layers=num_weighted_avg_layers,
+ out_dim=dim,
+ num_audio_token=4,
+ need_global=enable_adain,
+ )
+ self.pose_embedder = nn.Conv3d(pose_embed_dim, dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ audio_hidden_states: torch.Tensor,
+ pose_hidden_states: Optional[torch.Tensor] = None,
+ timestep_seq_len: Optional[int] = None,
+ ):
+ timestep = self.timesteps_proj(timestep)
+ if timestep_seq_len is not None:
+ timestep = timestep.unflatten(0, (-1, timestep_seq_len))
+
+ time_embedder_dtype = get_parameter_dtype(self.time_embedder)
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ timestep = timestep.to(time_embedder_dtype)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+ timestep_proj = self.time_proj(self.act_fn(temb))
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+
+ audio_hidden_states = self.causal_audio_encoder(audio_hidden_states)
+
+ pose_hidden_states = self.pose_embedder(pose_hidden_states)
+
+ return temb, timestep_proj, encoder_hidden_states, audio_hidden_states, pose_hidden_states
+
+
+class WanS2VRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ num_attention_heads: int,
+ theta: float = 10000.0,
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.max_seq_len = max_seq_len
+ self.num_attention_heads = num_attention_heads
+
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
+ t_dim = attention_head_dim - h_dim - w_dim
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ freqs = []
+
+ for dim in [t_dim, h_dim, w_dim]:
+ freq = get_1d_rotary_pos_embed(
+ dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
+ )
+ freqs.append(freq)
+
+ self.freqs = torch.cat(freqs, dim=1)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ image_latents: Optional[torch.Tensor] = None,
+ grid_sizes: Optional[List[List[torch.Tensor]]] = None,
+ ) -> torch.Tensor:
+ if grid_sizes is None:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+ grid_sizes = torch.tensor([ppf, pph, ppw]).unsqueeze(0).repeat(batch_size, 1)
+ grid_sizes = [torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]
+
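+            # The reference-image tokens are given a fixed temporal position (frame index 30),
+            # separate from the video frames.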
+ image_grid_sizes = [
+ # The start index
+ torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
+ # The end index
+ torch.tensor([31, image_latents.shape[3] // p_h, image_latents.shape[4] // p_w])
+ .unsqueeze(0)
+ .repeat(batch_size, 1),
+ # The range
+ torch.tensor([1, image_latents.shape[3] // p_h, image_latents.shape[4] // p_w])
+ .unsqueeze(0)
+ .repeat(batch_size, 1),
+ ]
+
+ grids = [grid_sizes, image_grid_sizes]
+ S = ppf * pph * ppw + image_latents.shape[3] // p_h * image_latents.shape[4] // p_w
+ else: # FramePack's RoPE
+ batch_size, S, _, _ = hidden_states.shape
+ grids = grid_sizes
+
+ split_sizes = [
+ self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
+ self.attention_head_dim // 6,
+ self.attention_head_dim // 6,
+ ]
+
+ freqs = self.freqs.split(split_sizes, dim=1)
+
+ # Loop over samples
+ output = torch.view_as_complex(
+ torch.zeros(
+ (batch_size, S, self.num_attention_heads, self.attention_head_dim // 2, 2),
+ device=hidden_states.device,
+ dtype=torch.float64,
+ )
+ )
+ seq_bucket = [0]
+ for g in grids:
+            if not isinstance(g, list):
+ g = [torch.zeros_like(g), g]
+ batch_size = g[0].shape[0]
+ for i in range(batch_size):
+ f_o, h_o, w_o = g[0][i]
+
+ f, h, w = g[1][i]
+ t_f, t_h, t_w = g[2][i]
+ seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
+ seq_len = int(seq_f * seq_h * seq_w)
+ if seq_len > 0:
+ if t_f > 0:
+                        # Sample `seq_f` integer time positions spanning this bucket's temporal range;
+                        # negative offsets are mirrored to positive indices here and handled via
+                        # conjugation below.
+ if f_o >= 0:
+ f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist()
+ else:
+ f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist()
+ h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist()
+ w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist()
+
+ assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
+ freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj()
+ freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
+
+ freqs_i = torch.cat(
+ [
+ freqs_0.expand(seq_f, seq_h, seq_w, -1),
+ freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1),
+ freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1),
+ ],
+ dim=-1,
+ ).reshape(seq_len, 1, -1)
+
+ # apply rotary embedding
+ output[i, seq_bucket[-1] : seq_bucket[-1] + seq_len] = freqs_i
+ seq_bucket.append(seq_bucket[-1] + seq_len)
+
+ return output
+
+
+@maybe_allow_in_graph
+class WanS2VTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ qk_norm: str = "rms_norm_across_heads",
+ cross_attn_norm: bool = False,
+ eps: float = 1e-6,
+ added_kv_proj_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # 1. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.attn1 = WanAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ cross_attention_dim_head=None,
+ processor=WanS2VAttnProcessor(),
+ )
+
+ # 2. Cross-attention
+ self.attn2 = WanAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ added_kv_proj_dim=added_kv_proj_dim,
+ cross_attention_dim_head=dim // num_heads,
+ processor=WanS2VAttnProcessor(),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+ # 3. Feed-forward
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: Tuple[torch.Tensor, torch.Tensor],
+ rotary_emb: torch.Tensor,
+ ) -> torch.Tensor:
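+        # `temb` is a pair: the 6-way modulation parameters (with a segment axis of size 2) and the
+        # boundary index separating the noisy video tokens from the condition (image/motion) tokens.
+        # Each segment gets its own shift/scale/gate below.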
+ seg_idx = temb[1].item()
+ seg_idx = min(max(0, seg_idx), hidden_states.shape[1])
+ seg_idx = [0, seg_idx, hidden_states.shape[1]]
+ temb = temb[0]
+ # temb: batch_size, 6, 2, inner_dim
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table.unsqueeze(2) + temb.float()
+ ).chunk(6, dim=1)
+ # batch_size, 1, seq_len, inner_dim
+ shift_msa = shift_msa.squeeze(1)
+ scale_msa = scale_msa.squeeze(1)
+ gate_msa = gate_msa.squeeze(1)
+ c_shift_msa = c_shift_msa.squeeze(1)
+ c_scale_msa = c_scale_msa.squeeze(1)
+ c_gate_msa = c_gate_msa.squeeze(1)
+
+ norm_hidden_states = self.norm1(hidden_states.float())
+ parts = []
+ for i in range(2):
+ parts.append(
+ norm_hidden_states[:, seg_idx[i] : seg_idx[i + 1]] * (1 + scale_msa[:, i : i + 1])
+ + shift_msa[:, i : i + 1]
+ )
+ norm_hidden_states = torch.cat(parts, dim=1).type_as(hidden_states)
+
+ # 1. Self-attention
+ attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
+ z = []
+ for i in range(2):
+ z.append(attn_output[:, seg_idx[i] : seg_idx[i + 1]] * gate_msa[:, i : i + 1])
+ attn_output = torch.cat(z, dim=1)
+ hidden_states = (hidden_states.float() + attn_output).type_as(hidden_states)
+
+ # 2. Cross-attention
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
+ hidden_states = hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm3_hidden_states = self.norm3(hidden_states.float())
+ parts = []
+ for i in range(2):
+ parts.append(
+ norm3_hidden_states[:, seg_idx[i] : seg_idx[i + 1]] * (1 + c_scale_msa[:, i : i + 1])
+ + c_shift_msa[:, i : i + 1]
+ )
+ norm3_hidden_states = torch.cat(parts, dim=1).type_as(hidden_states)
+ ff_output = self.ffn(norm3_hidden_states)
+ z = []
+ for i in range(2):
+ z.append(ff_output[:, seg_idx[i] : seg_idx[i + 1]] * c_gate_msa[:, i : i + 1])
+ ff_output = torch.cat(z, dim=1)
+ hidden_states = (hidden_states.float() + ff_output.float()).type_as(hidden_states)
+
+ return hidden_states
+
+
+class WanS2VTransformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
+ r"""
+ A Transformer model for video-like data used in the Wan2.2-S2V model.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+        num_attention_heads (`int`, defaults to `40`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        text_dim (`int`, defaults to `4096`):
+            Input dimension for text embeddings.
+        freq_dim (`int`, defaults to `256`):
+            Dimension for sinusoidal time embeddings.
+        audio_dim (`int`, defaults to `1024`):
+            Input dimension for the audio (wav2vec) embeddings.
+        audio_inject_layers (`Tuple[int]`, defaults to `(0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39)`):
+            Indices of the transformer blocks after which audio cross-attention is injected.
+        enable_adain (`bool`, defaults to `True`):
+            Whether to modulate the audio injection with AdaLayerNorm driven by a global audio embedding.
+        adain_mode (`str`, defaults to `"attn_norm"`):
+            How the AdaLayerNorm conditioning is applied inside the audio injector.
+        pose_dim (`int`, defaults to `16`):
+            The number of channels in the pose condition latents.
+        ffn_dim (`int`, defaults to `13824`):
+            Intermediate dimension in feed-forward network.
+        num_layers (`int`, defaults to `40`):
+            The number of layers of transformer blocks to use.
+        num_weighted_avg_layers (`int`, defaults to `25`):
+            The number of wav2vec layers averaged by the causal audio encoder.
+        cross_attn_norm (`bool`, defaults to `True`):
+            Enable cross-attention normalization.
+        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
+            Enable query/key normalization.
+        eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+        rope_max_seq_len (`int`, defaults to `1024`):
+            Maximum sequence length used to precompute the rotary embedding frequencies.
+        enable_framepack (`bool`, defaults to `True`):
+            Whether to encode the motion latents with the FramePack-style motioner instead of plain patchification.
+        framepack_drop_mode (`str`, defaults to `"padd"`):
+            How excluded motion buckets are handled by the frame packer: `"drop"` removes their tokens, `"padd"`
+            zeroes them.
+        add_last_motion (`bool`, defaults to `True`):
+            Global switch for including the most recent motion latents; multiplied with the `add_last_motion`
+            value passed to `forward`.
+        zero_timestep (`bool`, defaults to `True`):
+            Whether to assign a zero-valued timestep to the image/motion condition tokens.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+ _no_split_modules = ["WanS2VTransformerBlock"]
+ _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3", "causal_audio_encoder"]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+ _repeated_blocks = ["WanS2VTransformerBlock"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (1, 2, 2),
+ num_attention_heads: int = 40,
+ attention_head_dim: int = 128,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ text_dim: int = 4096,
+ freq_dim: int = 256,
+ audio_dim: int = 1024,
+ audio_inject_layers: Tuple[int] = (0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39),
+ enable_adain: bool = True,
+ adain_mode: str = "attn_norm",
+ pose_dim: int = 16,
+ ffn_dim: int = 13824,
+ num_layers: int = 40,
+ num_weighted_avg_layers: int = 25,
+ cross_attn_norm: bool = True,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ eps: float = 1e-6,
+ added_kv_proj_dim: Optional[int] = None,
+ rope_max_seq_len: int = 1024,
+ enable_framepack: bool = True,
+ framepack_drop_mode: str = "padd",
+ add_last_motion: bool = True,
+ zero_timestep: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Patch & position embedding
+ self.rope = WanS2VRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len, num_attention_heads)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ if enable_framepack:
+ self.frame_packer = FramePackMotioner(
+ inner_dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ zip_frame_buckets=[1, 2, 16],
+ drop_mode=framepack_drop_mode,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ )
+ else:
+ self.motion_in = Motioner(
+ inner_dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ rope_max_seq_len=rope_max_seq_len,
+ )
+
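+        # Token-type embedding added to every token: 0 = noisy video latents, 1 = reference image
+        # latents, 2 = motion latents (set in `forward` and `inject_motion`).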
+ self.trainable_condition_mask = nn.Embedding(3, inner_dim)
+
+ # 2. Condition Embeddings
+ self.condition_embedder = WanTimeTextAudioPoseEmbedding(
+ dim=inner_dim,
+ time_freq_dim=freq_dim,
+ time_proj_dim=inner_dim * 6,
+ text_embed_dim=text_dim,
+ audio_embed_dim=audio_dim,
+ pose_embed_dim=pose_dim,
+ patch_size=patch_size,
+ enable_adain=enable_adain,
+ num_weighted_avg_layers=num_weighted_avg_layers,
+ )
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ WanS2VTransformerBlock(
+ inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Audio Injector
+ self.audio_injector = AudioInjector(
+ num_injection_layers=len(audio_inject_layers),
+ inject_layers=audio_inject_layers,
+ dim=inner_dim,
+ num_heads=num_attention_heads,
+ enable_adain=enable_adain,
+ adain_dim=inner_dim,
+ adain_mode=adain_mode,
+ eps=eps,
+ )
+
+        # 5. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+ self.gradient_checkpointing = False
+
+ def process_motion(self, motion_latents, drop_motion_frames=False):
+        flat_mot, mot_remb = self.motion_in(motion_latents)
+
+        if drop_motion_frames or motion_latents[0].shape[1] == 0:
+            return [], []
+        else:
+            return flat_mot, mot_remb
+
+ def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2):
+        flat_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion)
+
+        if drop_motion_frames:
+            return flat_mot[:, :0], mot_remb[:, :0]
+        else:
+            return flat_mot, mot_remb
+
+ def inject_motion(
+ self,
+ hidden_states,
+ seq_lens,
+ rope_embs,
+ mask_input,
+ motion_latents,
+ drop_motion_frames=False,
+ add_last_motion=True,
+ ):
+        # Inject the motion-frame tokens into the hidden states
+ if self.config.enable_framepack:
+ mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames, add_last_motion)
+ else:
+ mot, mot_remb = self.process_motion(motion_latents, drop_motion_frames)
+
+ if len(mot) > 0:
+ hidden_states = torch.cat([hidden_states, mot], dim=1)
+ seq_lens = seq_lens + torch.tensor([mot.shape[1]], dtype=torch.long)
+ rope_embs = torch.cat([rope_embs, mot_remb], dim=1)
+ mask_input = torch.cat(
+ [
+ mask_input,
+ 2
+ * torch.ones(
+ [1, hidden_states.shape[1] - mask_input.shape[1]],
+ device=mask_input.device,
+ dtype=mask_input.dtype,
+ ),
+ ],
+ dim=1,
+ )
+ return hidden_states, seq_lens, rope_embs, mask_input
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ motion_latents: torch.Tensor,
+ audio_embeds: torch.Tensor,
+ image_latents: torch.Tensor,
+ pose_latents: torch.Tensor,
+ motion_frames: List[int] = [17, 5],
+ drop_motion_frames: bool = False,
+ add_last_motion: int = 2,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ r"""
+ Parameters:
+            audio_embeds:
+                The input audio embedding of shape `[B, num_wav2vec_layers, C_a, T_a]`.
+            motion_frames:
+                The number of pixel-space motion frames and VAE-encoded motion latent frames, e.g. `[17, 5]`.
+            add_last_motion:
+                For the plain motioner, any `add_last_motion > 0` includes the most recent (last) motion frame.
+                For frame packing: `add_last_motion = 0` keeps only the farthest latents (`clean_latents_4x`),
+                `add_last_motion = 1` keeps both `clean_latents_2x` and `clean_latents_4x`, and
+                `add_last_motion = 2` keeps all motion-related latents.
+            drop_motion_frames:
+                Whether to drop the motion-frame information entirely.
+ """
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+ add_last_motion = self.config.add_last_motion * add_last_motion
+
+ # 1. Rotary position embeddings
+ rotary_emb = self.rope(hidden_states, image_latents)
+
+ # 2. Patch embeddings
+ hidden_states = self.patch_embedding(hidden_states)
+ image_latents = self.patch_embedding(image_latents)
+
+ # 3. Condition embeddings
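+        # Pad the audio on the time axis by repeating its first frame `motion_frames[0]` times so the
+        # audio track also covers the motion-history frames that precede the clip.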
+ audio_embeds = torch.cat(
+ [audio_embeds[..., 0].unsqueeze(-1).repeat(1, 1, 1, motion_frames[0]), audio_embeds], dim=-1
+ )
+
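+        # When `zero_timestep` is enabled, append an all-zero timestep; its embedding is later used to
+        # modulate the image/motion condition tokens (see the segment split in `WanS2VTransformerBlock`).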
+ if self.config.zero_timestep:
+ timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
+
+ temb, timestep_proj, encoder_hidden_states, audio_hidden_states, pose_hidden_states = self.condition_embedder(
+ timestep, encoder_hidden_states, audio_embeds, pose_latents
+ )
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+ if self.config.enable_adain:
+ audio_emb_global, audio_emb = audio_hidden_states
+ audio_emb_global = audio_emb_global[:, motion_frames[1] :].clone()
+ else:
+ audio_emb = audio_hidden_states
+ merged_audio_emb = audio_emb[:, motion_frames[1] :, :]
+
+ hidden_states = hidden_states + pose_hidden_states
+
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+ image_latents = image_latents.flatten(2).transpose(1, 2)
+
+ sequence_length = torch.tensor([hidden_states.shape[1]], dtype=torch.long)
+ original_sequence_length = sequence_length
+ sequence_length = sequence_length + torch.tensor([image_latents.shape[1]], dtype=torch.long)
+ hidden_states = torch.cat([hidden_states, image_latents], dim=1)
+
+        # Initialize a mask indicating the token type: noisy latent, image latent, or motion latent.
+        # At this point only the first two (noisy and image latents) are marked; the motion latents
+        # are marked inside `inject_motion`.
+ mask_input = torch.zeros([1, hidden_states.shape[1]], dtype=torch.long, device=hidden_states.device)
+ mask_input[:, original_sequence_length:] = 1
+
+ hidden_states, sequence_length, rotary_emb, mask_input = self.inject_motion(
+ hidden_states,
+ sequence_length,
+ rotary_emb,
+ mask_input,
+ motion_latents,
+ drop_motion_frames,
+ add_last_motion,
+ )
+
+ hidden_states = hidden_states + self.trainable_condition_mask(mask_input).to(hidden_states.dtype)
+
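+        # Build per-segment modulation parameters: index 0 of the new axis modulates the noisy video
+        # tokens with the sampled timestep, index 1 modulates the condition tokens with the zero timestep.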
+ if self.config.zero_timestep:
+ temb = temb[:-1]
+ zero_timestep_proj = timestep_proj[-1:]
+ timestep_proj = timestep_proj[:-1]
+ timestep_proj = torch.cat(
+ [timestep_proj.unsqueeze(2), zero_timestep_proj.unsqueeze(2).repeat(timestep_proj.shape[0], 1, 1, 1)],
+ dim=2,
+ )
+ timestep_proj = [timestep_proj, original_sequence_length]
+ else:
+ timestep_proj = timestep_proj.unsqueeze(2).repeat(1, 1, 2, 1)
+ timestep_proj = [timestep_proj, 0]
+
+ merged_audio_emb_num_frames = merged_audio_emb.shape[1] # B F N C
+ attn_audio_emb = merged_audio_emb.flatten(0, 1).to(hidden_states.dtype)
+ audio_emb_global = audio_emb_global.flatten(0, 1).to(hidden_states.dtype)
+
+ # 5. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block_idx, block in enumerate(self.blocks):
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ timestep_proj,
+ rotary_emb,
+ )
+ if block_idx in self.audio_injector.injected_block_id.keys():
+ hidden_states = self.audio_injector(
+ block_idx,
+ hidden_states,
+ original_sequence_length,
+ merged_audio_emb_num_frames,
+ attn_audio_emb,
+ audio_emb_global,
+ )
+ else:
+ for block_idx, block in enumerate(self.blocks):
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+ if block_idx in self.audio_injector.injected_block_id.keys():
+ hidden_states = self.audio_injector(
+ block_idx,
+ hidden_states,
+ original_sequence_length,
+ merged_audio_emb_num_frames,
+ attn_audio_emb,
+ audio_emb_global,
+ )
+
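+        # Discard the image and motion condition tokens; only the noisy video tokens are decoded back
+        # into latents.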
+ hidden_states = hidden_states[:, :original_sequence_length]
+
+ # 6. Output norm, projection & unpatchify
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py
index 151adbbc0320..c7285e38fda2 100644
--- a/src/diffusers/modular_pipelines/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/modular_pipeline.py
@@ -861,6 +861,10 @@ def __init__(self):
else:
sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
+        if len(self.block_names) != len(self.block_classes):
+ raise ValueError(
+ f"In {self.__class__.__name__}, the number of block_names and block_classes must be the same."
+ )
def _get_inputs(self):
inputs = []
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
index 83bfcb3da4fd..419894164389 100644
--- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
@@ -523,7 +523,7 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
QwenImageOptionalControlNetBeforeDenoiseStep,
QwenImageAutoDenoiseStep,
]
- block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise", "decode"]
+ block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise"]
@property
def description(self):
@@ -534,7 +534,6 @@ def description(self):
+ " - `QwenImageAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ " - `QwenImageOptionalControlNetBeforeDenoiseStep` (controlnet_before_denoise) prepares the controlnet input for the denoising step.\n"
+ " - `QwenImageAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
- + " - `QwenImageAutoDecodeStep` (decode) decodes the latents into images.\n\n"
+ "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 495753041f10..d21efe20ee4e 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -308,7 +308,10 @@
"SanaSprintPipeline",
"SanaControlNetPipeline",
"SanaSprintImg2ImgPipeline",
+ ]
+ _import_structure["sana_video"] = [
"SanaVideoPipeline",
+ "SanaImageToVideoPipeline",
]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
@@ -385,7 +388,14 @@
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
]
- _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
+ _import_structure["wan"] = [
+ "WanPipeline",
+ "WanImageToVideoPipeline",
+ "WanVideoToVideoPipeline",
+ "WanVACEPipeline",
+ "WanSpeechToVideoPipeline",
+ "WanAnimatePipeline",
+ ]
_import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"]
_import_structure["skyreels_v2"] = [
"SkyReelsV2DiffusionForcingPipeline",
@@ -743,8 +753,8 @@
SanaPipeline,
SanaSprintImg2ImgPipeline,
SanaSprintPipeline,
- SanaVideoPipeline,
)
+ from .sana_video import SanaImageToVideoPipeline, SanaVideoPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
@@ -803,7 +813,14 @@
UniDiffuserTextDecoder,
)
from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
- from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline
+ from .wan import (
+ WanAnimatePipeline,
+ WanImageToVideoPipeline,
+ WanPipeline,
+ WanSpeechToVideoPipeline,
+ WanVACEPipeline,
+ WanVideoToVideoPipeline,
+ )
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py
index ebddfb0c0eee..a22a756005ac 100644
--- a/src/diffusers/pipelines/bria/pipeline_bria.py
+++ b/src/diffusers/pipelines/bria/pipeline_bria.py
@@ -245,7 +245,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -489,11 +489,11 @@ def __call__(
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
index 85d29029e667..c66b64766edc 100644
--- a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
+++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
@@ -337,7 +337,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
@@ -498,11 +498,11 @@ def __call__(
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index bd23e657c408..8ca8b4419e18 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -590,9 +590,10 @@ def __call__(
the text `prompt`, usually at the expense of lower image quality.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -777,7 +778,7 @@ def __call__(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.guidance_rescale > 0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(
noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
index 537588f67c95..48a6f0837c8d 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -927,9 +927,10 @@ def __call__(
the text `prompt`, usually at the expense of lower image quality.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -1194,7 +1195,7 @@ def __call__(
timestep, _ = timestep.chunk(2)
if self.guidance_rescale > 0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(
noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index 694378b4f040..f30f8a3dc8f6 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -654,9 +654,10 @@ def __call__(
the text `prompt`, usually at the expense of lower image quality.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -851,7 +852,7 @@ def __call__(
timestep, _ = timestep.chunk(2)
if self.guidance_rescale > 0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(
noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
)
diff --git a/src/diffusers/pipelines/prx/pipeline_prx.py b/src/diffusers/pipelines/prx/pipeline_prx.py
index a3bd3e6b45e7..873f25316e6d 100644
--- a/src/diffusers/pipelines/prx/pipeline_prx.py
+++ b/src/diffusers/pipelines/prx/pipeline_prx.py
@@ -69,6 +69,39 @@
"2.0": [704, 352],
}
+ASPECT_RATIO_1024_BIN = {
+ "0.49": [704, 1440],
+ "0.52": [736, 1408],
+ "0.53": [736, 1376],
+ "0.57": [768, 1344],
+ "0.59": [768, 1312],
+ "0.62": [800, 1280],
+ "0.67": [832, 1248],
+ "0.68": [832, 1216],
+ "0.78": [896, 1152],
+ "0.83": [928, 1120],
+ "0.94": [992, 1056],
+ "1.0": [1024, 1024],
+ "1.06": [1056, 992],
+ "1.13": [1088, 960],
+ "1.21": [1120, 928],
+ "1.29": [1152, 896],
+ "1.37": [1184, 864],
+ "1.46": [1216, 832],
+ "1.5": [1248, 832],
+ "1.71": [1312, 768],
+ "1.75": [1344, 768],
+ "1.87": [1376, 736],
+ "1.91": [1408, 736],
+ "2.05": [1440, 704],
+}
+
+ASPECT_RATIO_BINS = {
+ 256: ASPECT_RATIO_256_BIN,
+ 512: ASPECT_RATIO_512_BIN,
+ 1024: ASPECT_RATIO_1024_BIN,
+}
+
logger = logging.get_logger(__name__)
@@ -536,11 +569,11 @@ def __call__(
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -600,10 +633,12 @@ def __call__(
"Resolution binning requires a VAE with image_processor, but VAE is not available. "
"Set use_resolution_binning=False or provide a VAE."
)
- if self.default_sample_size <= 256:
- aspect_ratio_bin = ASPECT_RATIO_256_BIN
- else:
- aspect_ratio_bin = ASPECT_RATIO_512_BIN
+ if self.default_sample_size not in ASPECT_RATIO_BINS:
+ raise ValueError(
+ f"Resolution binning is only supported for default_sample_size in {list(ASPECT_RATIO_BINS.keys())}, "
+ f"but got {self.default_sample_size}. Set use_resolution_binning=False to disable aspect ratio binning."
+ )
+ aspect_ratio_bin = ASPECT_RATIO_BINS[self.default_sample_size]
# Store original dimensions
orig_height, orig_width = height, width
diff --git a/src/diffusers/pipelines/sana/__init__.py b/src/diffusers/pipelines/sana/__init__.py
index d5571ab12fac..91684f35f153 100644
--- a/src/diffusers/pipelines/sana/__init__.py
+++ b/src/diffusers/pipelines/sana/__init__.py
@@ -26,7 +26,6 @@
_import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"]
_import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
_import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"]
- _import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -40,7 +39,6 @@
from .pipeline_sana_controlnet import SanaControlNetPipeline
from .pipeline_sana_sprint import SanaSprintPipeline
from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline
- from .pipeline_sana_video import SanaVideoPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/sana/pipeline_output.py b/src/diffusers/pipelines/sana/pipeline_output.py
index 8021b7738755..f8ac12951644 100644
--- a/src/diffusers/pipelines/sana/pipeline_output.py
+++ b/src/diffusers/pipelines/sana/pipeline_output.py
@@ -3,7 +3,6 @@
import numpy as np
import PIL.Image
-import torch
from ...utils import BaseOutput
@@ -20,18 +19,3 @@ class SanaPipelineOutput(BaseOutput):
"""
images: Union[List[PIL.Image.Image], np.ndarray]
-
-
-@dataclass
-class SanaVideoPipelineOutput(BaseOutput):
- r"""
- Output class for Sana-Video pipelines.
-
- Args:
- frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
- List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
- denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
- `(batch_size, num_frames, channels, height, width)`.
- """
-
- frames: torch.Tensor
diff --git a/src/diffusers/pipelines/sana_video/__init__.py b/src/diffusers/pipelines/sana_video/__init__.py
new file mode 100644
index 000000000000..73e224bf749d
--- /dev/null
+++ b/src/diffusers/pipelines/sana_video/__init__.py
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
+ _import_structure["pipeline_sana_video_i2v"] = ["SanaImageToVideoPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_sana_video import SanaVideoPipeline
+ from .pipeline_sana_video_i2v import SanaImageToVideoPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/sana_video/pipeline_output.py b/src/diffusers/pipelines/sana_video/pipeline_output.py
new file mode 100644
index 000000000000..4d37923889eb
--- /dev/null
+++ b/src/diffusers/pipelines/sana_video/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class SanaVideoPipelineOutput(BaseOutput):
+ r"""
+ Output class for Sana-Video pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+        List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
+        denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py
similarity index 98%
rename from src/diffusers/pipelines/sana/pipeline_sana_video.py
rename to src/diffusers/pipelines/sana_video/pipeline_sana_video.py
index 5ec498faffb9..a786275e45a9 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_video.py
+++ b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py
@@ -95,17 +95,16 @@
>>> from diffusers import SanaVideoPipeline
>>> from diffusers.utils import export_to_video
- >>> model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
- >>> pipe = SanaVideoPipeline.from_pretrained(model_id)
+ >>> pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers")
>>> pipe.transformer.to(torch.bfloat16)
>>> pipe.text_encoder.to(torch.bfloat16)
>>> pipe.vae.to(torch.float32)
>>> pipe.to("cuda")
- >>> model_score = 30
+ >>> motion_score = 30
>>> prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
>>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
- >>> motion_prompt = f" motion score: {model_score}."
+ >>> motion_prompt = f" motion score: {motion_score}."
>>> prompt = prompt + motion_prompt
>>> output = pipe(
@@ -231,6 +230,7 @@ def __init__(
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
def _get_gemma_prompt_embeds(
self,
prompt: Union[str, List[str]],
@@ -827,9 +827,9 @@ def __call__(
Examples:
Returns:
- [`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] or `tuple`:
- If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] is returned,
- otherwise a `tuple` is returned where the first element is a list with the generated videos
+ [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated videos
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
diff --git a/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py b/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py
new file mode 100644
index 000000000000..e87880b64cee
--- /dev/null
+++ b/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py
@@ -0,0 +1,1066 @@
+# Copyright 2025 SANA-Video Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import PIL
+import torch
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import SanaLoraLoaderMixin
+from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ USE_PEFT_BACKEND,
+ is_bs4_available,
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SanaVideoPipelineOutput
+from .pipeline_sana_video import ASPECT_RATIO_480_BIN, ASPECT_RATIO_720_BIN
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import SanaImageToVideoPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+
+ >>> pipe = SanaImageToVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers")
+ >>> pipe.transformer.to(torch.bfloat16)
+ >>> pipe.text_encoder.to(torch.bfloat16)
+ >>> pipe.vae.to(torch.float32)
+ >>> pipe.to("cuda")
+ >>> motion_score = 30
+
+ >>> prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle."
+ >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
+ >>> motion_prompt = f" motion score: {motion_score}."
+ >>> prompt = prompt + motion_prompt
+ >>> image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")
+
+ >>> output = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=480,
+ ... width=832,
+ ... frames=81,
+ ... guidance_scale=6,
+ ... num_inference_steps=50,
+ ... generator=torch.Generator(device="cuda").manual_seed(42),
+ ... ).frames[0]
+
+ >>> export_to_video(output, "sana-ti2v-output.mp4", fps=16)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
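+
+
+# Illustrative usage sketch (not called in this module): `retrieve_timesteps` can be driven
+# either by a step count or by a custom sigma schedule. The scheduler instance and the sigma
+# values below are arbitrary assumptions, shown for demonstration only.
+#
+#     scheduler = FlowMatchEulerDiscreteScheduler()
+#     timesteps, num_steps = retrieve_timesteps(scheduler, sigmas=[1.0, 0.75, 0.5, 0.25], device="cpu")
+#     assert num_steps == 4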
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
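+
+
+# Note (illustrative): `AutoencoderKLWan.encode` returns an output exposing `latent_dist`, so
+# `retrieve_latents(..., sample_mode="argmax")` (used in `prepare_latents` below) returns the
+# distribution mode rather than a random sample, keeping the encoded conditioning frame
+# deterministic across runs.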
+
+
+class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
+ r"""
+ Pipeline for image/text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). This model
+ inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all
+ pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]):
+ The tokenizer used to tokenize the prompt.
+ text_encoder ([`Gemma2PreTrainedModel`]):
+ Text encoder model to encode the input prompts.
+        vae ([`AutoencoderKLWan`] or [`AutoencoderDC`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ transformer ([`SanaVideoTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ # fmt: off
+ bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
+ # fmt: on
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ text_encoder: Gemma2PreTrainedModel,
+ vae: Union[AutoencoderDC, AutoencoderKLWan],
+ transformer: SanaVideoTransformer3DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+
+ self.vae_scale_factor = self.vae_scale_factor_spatial
+
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size[1] if getattr(self, "transformer", None) is not None else 1
+ )
+ self.transformer_temporal_patch_size = (
+            self.transformer.config.patch_size[0] if getattr(self, "transformer", None) is not None else 1
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
+ def _get_gemma_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ dtype: torch.dtype,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+ # prepare complex human instruction
+ if not complex_human_instruction:
+ max_length_all = max_sequence_length
+ else:
+ chi_prompt = "\n".join(complex_human_instruction)
+ prompt = [chi_prompt + p for p in prompt]
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+ max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length_all,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
+
+ return prompt_embeds, prompt_attention_mask
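+
+    # Note (illustrative): when `complex_human_instruction` is supplied, the instruction block is
+    # prepended to every prompt and the tokenizer max length is extended to
+    # `len(tokenizer.encode(chi_prompt)) + max_sequence_length - 2`, so the user prompt keeps
+    # roughly `max_sequence_length` tokens of its own budget.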
+
+ # Copied from diffusers.pipelines.sana_video.pipeline_sana_video.SanaVideoPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ number of videos that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. For Sana, it should be the embeddings of the "" string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+
+ if device is None:
+ device = self._execution_device
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = None
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+ select_index = [0] + list(range(-max_length + 1, 0))
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ )
+
+ prompt_embeds = prompt_embeds[:, select_index]
+ prompt_attention_mask = prompt_attention_mask[:, select_index]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=negative_prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=False,
+ )
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ if self.text_encoder is not None:
+ if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
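+
+    # Illustrative sketch (assumptions: `pipe` is an instantiated SanaImageToVideoPipeline and
+    # `image` is a conditioning image): prompt embeddings can be computed once with `encode_prompt`
+    # and reused across calls by passing them back through the corresponding `__call__` arguments.
+    #
+    #     embeds, mask, neg_embeds, neg_mask = pipe.encode_prompt("a cat", negative_prompt="")
+    #     video = pipe(
+    #         image=image,
+    #         prompt_embeds=embeds,
+    #         prompt_attention_mask=mask,
+    #         negative_prompt_embeds=neg_embeds,
+    #         negative_prompt_attention_mask=neg_mask,
+    #     ).frames[0]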
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+        # all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+        # normalize all quotation marks to one standard
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip addresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ def prepare_latents(
+ self,
+ image: PipelineImageInput,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ image = image.unsqueeze(2) # [B, C, 1, H, W]
+ image = image.to(device=device, dtype=self.vae.dtype)
+
+ if isinstance(generator, list):
+ image_latents = [retrieve_latents(self.vae.encode(image), sample_mode="argmax") for _ in generator]
+ image_latents = torch.cat(image_latents)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
+ image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, -1, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(
+ image_latents.device, image_latents.dtype
+ )
+ image_latents = (image_latents - latents_mean) * latents_std
+
+ latents[:, :, 0:1] = image_latents.to(dtype)
+
+ return latents
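+
+    # Shape sketch (assuming the default SANA-Video VAE config with scale_factor_temporal=4 and
+    # scale_factor_spatial=8): a request for 81 frames at 480x832 yields latents of shape
+    # [batch, num_channels_latents, (81 - 1) // 4 + 1 = 21, 480 // 8 = 60, 832 // 8 = 104],
+    # after which the first latent frame is replaced by the encoded conditioning image.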
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ height: int = 480,
+ width: int = 832,
+ frames: int = 81,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ clean_caption: bool = False,
+ use_resolution_binning: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 300,
+ complex_human_instruction: List[str] = [
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for video generation. Evaluate the level of detail in the user prompt:",
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, motion, and temporal relationships to create vivid and dynamic scenes.",
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+ "Here are examples of how to transform or refine prompts:",
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat slowly settling into a curled position, peacefully falling asleep on a warm sunny windowsill, with gentle sunlight filtering through surrounding pots of blooming red flowers.",
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps gradually lighting up, a diverse crowd of people in colorful clothing walking past, and a double-decker bus smoothly passing by towering glass skyscrapers.",
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+ "User Prompt: ",
+ ],
+ ) -> Union[SanaVideoPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the video generation on. The first frame of the generated video will be
+ conditioned on this image.
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+            guidance_scale (`float`, *optional*, defaults to 6.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to
+ the text `prompt`, usually at the expense of lower video quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ height (`int`, *optional*, defaults to 480):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to 832):
+ The width in pixels of the generated video.
+ frames (`int`, *optional*, defaults to 81):
+ The number of frames in the generated video.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video. Choose between `"pil"`, `"np"`, `"pt"`, or `"latent"`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`SanaVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs:
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            clean_caption (`bool`, *optional*, defaults to `False`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ use_resolution_binning (`bool` defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ `ASPECT_RATIO_480_BIN` or `ASPECT_RATIO_720_BIN`. After the produced latents are decoded into videos,
+ they are resized back to the requested resolution. Useful for generating non-square videos.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `300`):
+ Maximum sequence length to use with the `prompt`.
+ complex_human_instruction (`List[str]`, *optional*):
+ Instructions for complex human attention:
+ https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated videos
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 30:
+ aspect_ratio_bin = ASPECT_RATIO_480_BIN
+ elif self.transformer.config.sample_size == 22:
+ aspect_ratio_bin = ASPECT_RATIO_720_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.video_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ self.check_inputs(
+ prompt,
+ image,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ lora_scale=lora_scale,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+
+ latents = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ height,
+ width,
+ frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ conditioning_mask = latents.new_zeros(
+ batch_size,
+ 1,
+ latents.shape[2] // self.transformer_temporal_patch_size,
+ latents.shape[3] // self.transformer_spatial_patch_size,
+ latents.shape[4] // self.transformer_spatial_patch_size,
+ )
+ conditioning_mask[:, :, 0] = 1.0
+ if self.do_classifier_free_guidance:
+ conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ transformer_dtype = self.transformer.dtype
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(conditioning_mask.shape)
+ timestep = timestep * (1 - conditioning_mask)
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input.to(dtype=transformer_dtype),
+ encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=timestep,
+ return_dict=False,
+ attention_kwargs=self.attention_kwargs,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ timestep, _ = timestep.chunk(2)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+
+ noise_pred = noise_pred[:, :, 1:]
+ noise_latents = latents[:, :, 1:]
+ pred_latents = self.scheduler.step(
+ noise_pred, t, noise_latents, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = latents.to(self.vae.dtype)
+ torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+ oom_error = (
+ torch.OutOfMemoryError
+ if is_torch_version(">=", "2.5.0")
+ else torch_accelerator_module.OutOfMemoryError
+ )
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ try:
+ video = self.vae.decode(latents, return_dict=False)[0]
+ except oom_error as e:
+ warnings.warn(
+ f"{e}. \n"
+ f"Try to use VAE tiling for large images. For example: \n"
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+ )
+
+ if use_resolution_binning:
+ video = self.video_processor.resize_and_crop_tensor(video, orig_width, orig_height)
+
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SanaVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py
index 8562a5eaf0e6..d6cd7d7feceb 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py
@@ -415,11 +415,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `6.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py
index d0a4e118ce43..089f92632d38 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py
@@ -647,11 +647,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `6.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**)
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**)
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py
index 959cbb32f23a..2951a9447386 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py
@@ -698,11 +698,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `5.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**)
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**)
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
index d59b4ce3cb17..d61b687eadc3 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
@@ -524,11 +524,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `5.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
diff --git a/src/diffusers/pipelines/wan/__init__.py b/src/diffusers/pipelines/wan/__init__.py
index bb96372b1db2..324242ab477c 100644
--- a/src/diffusers/pipelines/wan/__init__.py
+++ b/src/diffusers/pipelines/wan/__init__.py
@@ -23,7 +23,9 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_wan"] = ["WanPipeline"]
+ _import_structure["pipeline_wan_animate"] = ["WanAnimatePipeline"]
_import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
+ _import_structure["pipeline_wan_s2v"] = ["WanSpeechToVideoPipeline"]
_import_structure["pipeline_wan_vace"] = ["WanVACEPipeline"]
_import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -35,10 +37,11 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_wan import WanPipeline
+ from .pipeline_wan_animate import WanAnimatePipeline
from .pipeline_wan_i2v import WanImageToVideoPipeline
+ from .pipeline_wan_s2v import WanSpeechToVideoPipeline
from .pipeline_wan_vace import WanVACEPipeline
from .pipeline_wan_video2video import WanVideoToVideoPipeline
-
else:
import sys
diff --git a/src/diffusers/pipelines/wan/image_processor.py b/src/diffusers/pipelines/wan/image_processor.py
new file mode 100644
index 000000000000..b1594d08630f
--- /dev/null
+++ b/src/diffusers/pipelines/wan/image_processor.py
@@ -0,0 +1,185 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from ...configuration_utils import register_to_config
+from ...image_processor import VaeImageProcessor
+from ...utils import PIL_INTERPOLATION
+
+
+class WanAnimateImageProcessor(VaeImageProcessor):
+ r"""
+ Image processor to preprocess the reference (character) image for the Wan Animate model.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
+ VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of
+ this factor.
+ vae_latent_channels (`int`, *optional*, defaults to `16`):
+ VAE latent channels.
+ spatial_patch_size (`Tuple[int, int]`, *optional*, defaults to `(2, 2)`):
+ The spatial patch size used by the diffusion transformer. For Wan models, this is typically (2, 2).
+ resample (`str`, *optional*, defaults to `lanczos`):
+ Resampling filter to use when resizing the image.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image to [-1,1].
+ do_binarize (`bool`, *optional*, defaults to `False`):
+ Whether to binarize the image to 0/1.
+        do_convert_rgb (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to RGB format.
+        do_convert_grayscale (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to grayscale format.
+        fill_color (`str` or `float` or `Tuple[float, ...]`, *optional*, defaults to `0`):
+            An optional fill color used when `resize_mode` is set to `"fill"`. The empty space is filled with this
+            color instead of data from the image. Any valid `color` argument to `PIL.Image.new` is accepted; if
+            `None`, the empty space is filled with data from `image`.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ do_resize: bool = True,
+ vae_scale_factor: int = 8,
+ vae_latent_channels: int = 16,
+ spatial_patch_size: Tuple[int, int] = (2, 2),
+ resample: str = "lanczos",
+        reducing_gap: Optional[int] = None,
+ do_normalize: bool = True,
+ do_binarize: bool = False,
+ do_convert_rgb: bool = False,
+ do_convert_grayscale: bool = False,
+ fill_color: Optional[Union[str, float, Tuple[float, ...]]] = 0,
+ ):
+ super().__init__()
+ if do_convert_rgb and do_convert_grayscale:
+ raise ValueError(
+ "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
+ " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
+ " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
+ )
+
+ def _resize_and_fill(
+ self,
+ image: PIL.Image.Image,
+ width: int,
+ height: int,
+ ) -> PIL.Image.Image:
+ r"""
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+ the image within the dimensions, filling empty with data from image.
+
+ Args:
+ image (`PIL.Image.Image`):
+ The image to resize and fill.
+ width (`int`):
+ The width to resize the image to.
+ height (`int`):
+ The height to resize the image to.
+
+ Returns:
+ `PIL.Image.Image`:
+ The resized and filled image.
+ """
+
+ ratio = width / height
+ src_ratio = image.width / image.height
+ fill_with_image_data = self.config.fill_color is None
+ fill_color = self.config.fill_color or 0
+
+ src_w = width if ratio < src_ratio else image.width * height // image.height
+ src_h = height if ratio >= src_ratio else image.height * width // image.width
+
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
+ res = PIL.Image.new("RGB", (width, height), color=fill_color)
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+ if fill_with_image_data:
+ if ratio < src_ratio:
+ fill_height = height // 2 - src_h // 2
+ if fill_height > 0:
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+ res.paste(
+ resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
+ box=(0, fill_height + src_h),
+ )
+ elif ratio > src_ratio:
+ fill_width = width // 2 - src_w // 2
+ if fill_width > 0:
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+ res.paste(
+ resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
+ box=(fill_width + src_w, 0),
+ )
+
+ return res
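+
+    # Worked example (illustrative): for `width=832`, `height=480` and a square 512x512 input,
+    # the image is resized to 480x480 and pasted at a horizontal offset of
+    # 832 // 2 - 480 // 2 = 176; the two 176-pixel side bands are then filled either with
+    # `fill_color` or, when `fill_color` is None, with stretched edge columns of the image.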
+
+ def get_default_height_width(
+ self,
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ ) -> Tuple[int, int]:
+ r"""
+ Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
+
+ Args:
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
+ The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
+ should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
+ tensor, it should have shape `[batch, channels, height, width]`.
+ height (`Optional[int]`, *optional*, defaults to `None`):
+ The height of the preprocessed image. If `None`, the height of the `image` input will be used.
+ width (`Optional[int]`, *optional*, defaults to `None`):
+ The width of the preprocessed image. If `None`, the width of the `image` input will be used.
+
+ Returns:
+ `Tuple[int, int]`:
+ A tuple containing the height and width, both resized to the nearest integer multiple of
+ `vae_scale_factor * spatial_patch_size`.
+ """
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[2]
+ else:
+ height = image.shape[1]
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[3]
+ else:
+ width = image.shape[2]
+
+ max_area = width * height
+ aspect_ratio = height / width
+ mod_value_h = self.config.vae_scale_factor * self.config.spatial_patch_size[0]
+ mod_value_w = self.config.vae_scale_factor * self.config.spatial_patch_size[1]
+
+ # Try to preserve the aspect ratio
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value_h * mod_value_h
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value_w * mod_value_w
+
+ return height, width
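+
+    # Worked example (illustrative): with the defaults `vae_scale_factor=8` and
+    # `spatial_patch_size=(2, 2)`, both dimensions are snapped to multiples of 16 while the
+    # total pixel area and aspect ratio are preserved as closely as possible. A 1280x720 input
+    # maps back to itself:
+    #     max_area = 1280 * 720 = 921_600, aspect_ratio = 720 / 1280 = 0.5625
+    #     height = round(sqrt(921_600 * 0.5625)) // 16 * 16 = 720
+    #     width  = round(sqrt(921_600 / 0.5625)) // 16 * 16 = 1280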
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_animate.py b/src/diffusers/pipelines/wan/pipeline_wan_animate.py
new file mode 100644
index 000000000000..c7c983b2f7d4
--- /dev/null
+++ b/src/diffusers/pipelines/wan/pipeline_wan_animate.py
@@ -0,0 +1,1204 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import PIL
+import regex as re
+import torch
+import torch.nn.functional as F
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import WanLoraLoaderMixin
+from ...models import AutoencoderKLWan, WanAnimateTransformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .image_processor import WanAnimateImageProcessor
+from .pipeline_output import WanPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> import numpy as np
+ >>> from diffusers import WanAnimatePipeline
+ >>> from diffusers.utils import export_to_video, load_image, load_video
+
+ >>> model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
+ >>> pipe = WanAnimatePipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ >>> # Optionally upcast the Wan VAE to FP32
+ >>> pipe.vae.to(torch.float32)
+ >>> pipe.to("cuda")
+
+ >>> # Load the reference character image
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+ ... )
+
+ >>> # Load pose and face videos (preprocessed from reference video)
+ >>> # Note: Videos should be preprocessed to extract pose keypoints and face features
+ >>> # Refer to the Wan-Animate preprocessing documentation for details
+ >>> pose_video = load_video("path/to/pose_video.mp4")
+ >>> face_video = load_video("path/to/face_video.mp4")
+
+ >>> # CFG is generally not used for Wan Animate
+ >>> prompt = (
+ ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
+ ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+ ... )
+
+ >>> # Animation mode: Animate the character with the motion from pose/face videos
+ >>> output = pipe(
+ ... image=image,
+ ... pose_video=pose_video,
+ ... face_video=face_video,
+ ... prompt=prompt,
+ ... height=720,
+ ... width=1280,
+ ... segment_frame_length=77, # Frame length of each inference segment
+ ... guidance_scale=1.0,
+ ... num_inference_steps=20,
+ ... mode="animate",
+ ... ).frames[0]
+ >>> export_to_video(output, "output_animation.mp4", fps=30)
+
+ >>> # Replacement mode: Replace a character in the background video
+ >>> # Requires additional background_video and mask_video inputs
+ >>> background_video = load_video("path/to/background_video.mp4")
+ >>> mask_video = load_video("path/to/mask_video.mp4") # Black areas preserved, white areas generated
+ >>> output = pipe(
+ ... image=image,
+ ... pose_video=pose_video,
+ ... face_video=face_video,
+ ... background_video=background_video,
+ ... mask_video=mask_video,
+ ... prompt=prompt,
+ ... height=height,
+ ... width=width,
+ ... segment_frame_length=77, # Frame length of each inference segment
+ ... guidance_scale=1.0,
+ ... num_inference_steps=20,
+ ... mode="replace",
+ ... ).frames[0]
+ >>> export_to_video(output, "output_replacement.mp4", fps=30)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class WanAnimatePipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Pipeline for unified character animation and replacement using Wan-Animate.
+
+ WanAnimatePipeline takes a character image, pose video, and face video as input, and generates a video in two
+ modes:
+
+ 1. **Animation mode**: The model generates a video of the character image that mimics the human motion in the input
+ pose and face videos. The character is animated based on the provided motion controls, creating a new animated
+ video of the character.
+
+ 2. **Replacement mode**: The model replaces a character in a background video with the provided character image,
+ using the pose and face videos for motion control. This mode requires additional `background_video` and
+ `mask_video` inputs. The mask video should have black regions where the original content should be preserved and
+ white regions where the new character should be generated.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.WanLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ image_encoder ([`CLIPVisionModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
+ the
+ [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
+ variant.
+ transformer ([`WanAnimateTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ image_processor ([`CLIPImageProcessor`]):
+ Image processor for preprocessing images before encoding.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ vae: AutoencoderKLWan,
+ scheduler: UniPCMultistepScheduler,
+ image_processor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModel,
+ transformer: WanAnimateTransformer3DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ image_encoder=image_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_processor=image_processor,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.video_processor_for_mask = VideoProcessor(
+ vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_convert_grayscale=True
+ )
+ # In case self.transformer is None (e.g. for some pipeline tests)
+ spatial_patch_size = self.transformer.config.patch_size[-2:] if self.transformer is not None else (2, 2)
+ self.vae_image_processor = WanAnimateImageProcessor(
+ vae_scale_factor=self.vae_scale_factor_spatial,
+ spatial_patch_size=spatial_patch_size,
+ resample="bilinear",
+ fill_color=0,
+ )
+ self.image_processor = image_processor
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
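+ # Keep only the first seq_len token embeddings per prompt, then right-pad with zeros back to max_sequence_length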
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image
+ def encode_image(
+ self,
+ image: PipelineImageInput,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+ image = self.image_processor(images=image, return_tensors="pt").to(device)
+ image_embeds = self.image_encoder(**image, output_hidden_states=True)
+ return image_embeds.hidden_states[-2]
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ pose_video,
+ face_video,
+ background_video,
+ mask_video,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ mode=None,
+ prev_segment_conditioning_frames=None,
+ ):
+ if image is not None and image_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ if image is None and image_embeds is None:
+ raise ValueError(
+ "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+ )
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+ if pose_video is None:
+ raise ValueError("Provide `pose_video`. Cannot leave `pose_video` undefined.")
+ if face_video is None:
+ raise ValueError("Provide `face_video`. Cannot leave `face_video` undefined.")
+ if not isinstance(pose_video, list) or not isinstance(face_video, list):
+ raise ValueError("`pose_video` and `face_video` must be lists of PIL images.")
+ if len(pose_video) == 0 or len(face_video) == 0:
+ raise ValueError("`pose_video` and `face_video` must contain at least one frame.")
+ if mode == "replace" and (background_video is None or mask_video is None):
+ raise ValueError(
+ "Provide `background_video` and `mask_video`. Cannot leave both `background_video` and `mask_video`"
+ " undefined when mode is `replace`."
+ )
+ if mode == "replace" and (not isinstance(background_video, list) or not isinstance(mask_video, list)):
+ raise ValueError("`background_video` and `mask_video` must be lists of PIL images when mode is `replace`.")
+
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found"
+ f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if mode is not None and (not isinstance(mode, str) or mode not in ("animate", "replace")):
+ raise ValueError(
+ f"`mode` has to be of type `str` and in ('animate', 'replace') but its type is {type(mode)} and value is {mode}"
+ )
+
+ if prev_segment_conditioning_frames is not None and (
+ not isinstance(prev_segment_conditioning_frames, int) or prev_segment_conditioning_frames not in (1, 5)
+ ):
+ raise ValueError(
+ f"`prev_segment_conditioning_frames` has to be of type `int` and 1 or 5 but its type is"
+ f" {type(prev_segment_conditioning_frames)} and value is {prev_segment_conditioning_frames}"
+ )
+
+ def get_i2v_mask(
+ self,
+ batch_size: int,
+ latent_t: int,
+ latent_h: int,
+ latent_w: int,
+ mask_len: int = 1,
+ mask_pixel_values: Optional[torch.Tensor] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Union[str, torch.device] = "cuda",
+ ) -> torch.Tensor:
+ # mask_pixel_values shape (if supplied): [B, C = 1, T, latent_h, latent_w]
+ if mask_pixel_values is None:
+ mask_lat_size = torch.zeros(
+ batch_size, 1, (latent_t - 1) * 4 + 1, latent_h, latent_w, dtype=dtype, device=device
+ )
+ else:
+ mask_lat_size = mask_pixel_values.clone().to(device=device, dtype=dtype)
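+ # Mark the first `mask_len` pixel-space frames as known conditioning frames (1 = given, 0 = to be generated)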
+ mask_lat_size[:, :, :mask_len] = 1
+ first_frame_mask = mask_lat_size[:, :, 0:1]
+ # Repeat first frame mask self.vae_scale_factor_temporal (= 4) times in the frame dimension
+ first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:]], dim=2)
+ mask_lat_size = mask_lat_size.view(
+ batch_size, -1, self.vae_scale_factor_temporal, latent_h, latent_w
+ ).transpose(1, 2) # [B, C = 1, 4 * T_lat, H_lat, W_lat] --> [B, C = 4, T_lat, H_lat, W_lat]
+
+ return mask_lat_size
+
+ def prepare_reference_image_latents(
+ self,
+ image: torch.Tensor,
+ batch_size: int = 1,
+ sample_mode: str = "argmax",
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ) -> torch.Tensor:
+ # image shape: (B, C, H, W) or (B, C, T, H, W)
+ dtype = dtype or self.vae.dtype
+ if image.ndim == 4:
+ # Add a singleton frame dimension after the channels dimension
+ image = image.unsqueeze(2)
+
+ _, _, _, height, width = image.shape
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ # Encode image to latents using VAE
+ image = image.to(device=device, dtype=dtype)
+ if isinstance(generator, list):
+ # Like in prepare_latents, assume len(generator) == batch_size
+ ref_image_latents = [
+ retrieve_latents(self.vae.encode(image), generator=g, sample_mode=sample_mode) for g in generator
+ ]
+ ref_image_latents = torch.cat(ref_image_latents)
+ else:
+ ref_image_latents = retrieve_latents(self.vae.encode(image), generator, sample_mode)
+ # Standardize the encoded latents using the Wan VAE latent mean and std
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(ref_image_latents.device, ref_image_latents.dtype)
+ )
+ latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ ref_image_latents.device, ref_image_latents.dtype
+ )
+ ref_image_latents = (ref_image_latents - latents_mean) * latents_recip_std
+ # Handle the case where we supply one image and one generator, but batch_size > 1 (e.g. generating multiple
+ # videos per prompt)
+ if ref_image_latents.shape[0] == 1 and batch_size > 1:
+ ref_image_latents = ref_image_latents.expand(batch_size, -1, -1, -1, -1)
+
+ # Prepare I2V mask in latent space and prepend to the reference image latents along channel dim
+ reference_image_mask = self.get_i2v_mask(batch_size, 1, latent_height, latent_width, 1, None, dtype, device)
+ reference_image_latents = torch.cat([reference_image_mask, ref_image_latents], dim=1)
+
+ return reference_image_latents
+
+ def prepare_prev_segment_cond_latents(
+ self,
+ prev_segment_cond_video: Optional[torch.Tensor] = None,
+ background_video: Optional[torch.Tensor] = None,
+ mask_video: Optional[torch.Tensor] = None,
+ batch_size: int = 1,
+ segment_frame_length: int = 77,
+ start_frame: int = 0,
+ height: int = 720,
+ width: int = 1280,
+ prev_segment_cond_frames: int = 1,
+ task: str = "animate",
+ interpolation_mode: str = "bicubic",
+ sample_mode: str = "argmax",
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ) -> torch.Tensor:
+ # prev_segment_cond_video shape: (B, C, T, H, W) in pixel space if supplied
+ # background_video shape: (B, C, T, H, W) (same as prev_segment_cond_video shape)
+ # mask_video shape: (B, 1, T, H, W) (same as prev_segment_cond_video, but with only 1 channel)
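+ # Build the conditioning video for this segment (tail frames of the previous segment, padded with zeros in
+ # "animate" mode or with the background video in "replace" mode), VAE-encode it, and prepend an I2V mask
+ # along the channel dimension.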
+ dtype = dtype or self.vae.dtype
+ if prev_segment_cond_video is None:
+ if task == "replace":
+ prev_segment_cond_video = background_video[:, :, :prev_segment_cond_frames].to(dtype)
+ else:
+ cond_frames_shape = (batch_size, 3, prev_segment_cond_frames, height, width) # In pixel space
+ prev_segment_cond_video = torch.zeros(cond_frames_shape, dtype=dtype, device=device)
+
+ data_batch_size, channels, _, segment_height, segment_width = prev_segment_cond_video.shape
+ num_latent_frames = (segment_frame_length - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+ if segment_height != height or segment_width != width:
+ logger.warning(
+ f"Interpolating prev segment cond video from ({segment_width}, {segment_height}) to ({width}, {height})"
+ )
+ # Perform a 4D (spatial) rather than a 5D (spatiotemporal) reshape, following the original code
+ prev_segment_cond_video = prev_segment_cond_video.transpose(1, 2).flatten(0, 1) # [B * T, C, H, W]
+ prev_segment_cond_video = F.interpolate(
+ prev_segment_cond_video, size=(height, width), mode=interpolation_mode
+ )
+ prev_segment_cond_video = prev_segment_cond_video.unflatten(0, (batch_size, -1)).transpose(1, 2)
+
+ # Fill the remaining part of the cond video segment with zeros (if animating) or the background video (if
+ # replacing).
+ if task == "replace":
+ remaining_segment = background_video[:, :, prev_segment_cond_frames:].to(dtype)
+ else:
+ remaining_segment_frames = segment_frame_length - prev_segment_cond_frames
+ remaining_segment = torch.zeros(
+ batch_size, channels, remaining_segment_frames, height, width, dtype=dtype, device=device
+ )
+
+ # Prepend the conditioning frames from the previous segment to the remaining segment video in the frame dim
+ prev_segment_cond_video = prev_segment_cond_video.to(dtype=dtype)
+ full_segment_cond_video = torch.cat([prev_segment_cond_video, remaining_segment], dim=2)
+
+ if isinstance(generator, list):
+ if data_batch_size == len(generator):
+ prev_segment_cond_latents = [
+ retrieve_latents(self.vae.encode(full_segment_cond_video[i].unsqueeze(0)), g, sample_mode)
+ for i, g in enumerate(generator)
+ ]
+ elif data_batch_size == 1:
+ # Like prepare_latents, assume len(generator) == batch_size
+ prev_segment_cond_latents = [
+ retrieve_latents(self.vae.encode(full_segment_cond_video), g, sample_mode) for g in generator
+ ]
+ else:
+ raise ValueError(
+ f"The batch size of the prev segment video should be either {len(generator)} or 1 but is"
+ f" {data_batch_size}"
+ )
+ prev_segment_cond_latents = torch.cat(prev_segment_cond_latents)
+ else:
+ prev_segment_cond_latents = retrieve_latents(
+ self.vae.encode(full_segment_cond_video), generator, sample_mode
+ )
+ # Standardize the encoded latents using the Wan VAE latent mean and std
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(prev_segment_cond_latents.device, prev_segment_cond_latents.dtype)
+ )
+ latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ prev_segment_cond_latents.device, prev_segment_cond_latents.dtype
+ )
+ prev_segment_cond_latents = (prev_segment_cond_latents - latents_mean) * latents_recip_std
+
+ # Prepare I2V mask
+ if task == "replace":
+ mask_video = 1 - mask_video
+ mask_video = mask_video.permute(0, 2, 1, 3, 4)
+ mask_video = mask_video.flatten(0, 1)
+ mask_video = F.interpolate(mask_video, size=(latent_height, latent_width), mode="nearest")
+ mask_pixel_values = mask_video.unflatten(0, (batch_size, -1))
+ mask_pixel_values = mask_pixel_values.permute(0, 2, 1, 3, 4) # output shape: [B, C = 1, T, H_lat, W_lat]
+ else:
+ mask_pixel_values = None
+ prev_segment_cond_mask = self.get_i2v_mask(
+ batch_size,
+ num_latent_frames,
+ latent_height,
+ latent_width,
+ mask_len=prev_segment_cond_frames if start_frame > 0 else 0,
+ mask_pixel_values=mask_pixel_values,
+ dtype=dtype,
+ device=device,
+ )
+
+ # Prepend cond I2V mask to prev segment cond latents along channel dimension
+ prev_segment_cond_latents = torch.cat([prev_segment_cond_mask, prev_segment_cond_latents], dim=1)
+ return prev_segment_cond_latents
+
+ def prepare_pose_latents(
+ self,
+ pose_video: torch.Tensor,
+ batch_size: int = 1,
+ sample_mode: str = "argmax",
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ) -> torch.Tensor:
+ # pose_video shape: (B, C, T, H, W)
+ pose_video = pose_video.to(device=device, dtype=dtype if dtype is not None else self.vae.dtype)
+ if isinstance(generator, list):
+ pose_latents = [
+ retrieve_latents(self.vae.encode(pose_video), generator=g, sample_mode=sample_mode) for g in generator
+ ]
+ pose_latents = torch.cat(pose_latents)
+ else:
+ pose_latents = retrieve_latents(self.vae.encode(pose_video), generator, sample_mode)
+ # Standardize the encoded latents using the Wan VAE latent mean and std
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(pose_latents.device, pose_latents.dtype)
+ )
+ latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ pose_latents.device, pose_latents.dtype
+ )
+ pose_latents = (pose_latents - latents_mean) * latents_recip_std
+ if pose_latents.shape[0] == 1 and batch_size > 1:
+ pose_latents = pose_latents.expand(batch_size, -1, -1, -1, -1)
+ return pose_latents
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 77,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ shape = (batch_size, num_channels_latents, num_latent_frames + 1, latent_height, latent_width)
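+ # The extra leading latent frame mirrors the reference-image latent that is later concatenated along the
+ # frame dimension of the conditioning latents; it is skipped when decoding with the VAE.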
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ return latents
+
+ def pad_video_frames(self, frames: List[Any], num_target_frames: int) -> List[Any]:
+ """
+ Pads an array-like video `frames` to `num_target_frames` using a "reflect"-like strategy. The frame dimension
+ is assumed to be the first dimension. In the 1D case, we can visualize this strategy as follows:
+
+ pad_video_frames([1, 2, 3, 4, 5], 10) -> [1, 2, 3, 4, 5, 4, 3, 2, 1, 2]
+ """
+ idx = 0
+ flip = False
+ target_frames = []
+ while len(target_frames) < num_target_frames:
+ target_frames.append(deepcopy(frames[idx]))
+ if flip:
+ idx -= 1
+ else:
+ idx += 1
+ if idx == 0 or idx == len(frames) - 1:
+ flip = not flip
+
+ return target_frames
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ pose_video: List[PIL.Image.Image],
+ face_video: List[PIL.Image.Image],
+ background_video: Optional[List[PIL.Image.Image]] = None,
+ mask_video: Optional[List[PIL.Image.Image]] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 720,
+ width: int = 1280,
+ segment_frame_length: int = 77,
+ num_inference_steps: int = 20,
+ mode: str = "animate",
+ prev_segment_conditioning_frames: int = 1,
+ motion_encode_batch_size: Optional[int] = None,
+ guidance_scale: float = 1.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input character image to condition the generation on. Must be an image, a list of images or a
+ `torch.Tensor`.
+ pose_video (`List[PIL.Image.Image]`):
+ The input pose video to condition the generation on. Must be a list of PIL images.
+ face_video (`List[PIL.Image.Image]`):
+ The input face video to condition the generation on. Must be a list of PIL images.
+ background_video (`List[PIL.Image.Image]`, *optional*):
+ When mode is `"replace"`, the input background video to condition the generation on. Must be a list of
+ PIL images.
+ mask_video (`List[PIL.Image.Image]`, *optional*):
+ When mode is `"replace"`, the input mask video to condition the generation on. Must be a list of PIL
+ images.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ mode (`str`, defaults to `"animate"`):
+ The mode of the generation. Choose between `"animate"` and `"replace"`.
+ prev_segment_conditioning_frames (`int`, defaults to `1`):
+ The number of frames from the previous video segment to be used for temporal guidance. Must be 1 or 5
+ (both of the form 4N + 1, where N is a non-negative integer).
+ motion_encode_batch_size (`int`, *optional*):
+ The batch size for batched encoding of the face video via the motion encoder. This allows trading off
+ inference speed for lower memory usage by setting a smaller batch size. Will default to
+ `self.transformer.config.motion_encoder_batch_size` if not set.
+ height (`int`, defaults to `720`):
+ The height of the generated video.
+ width (`int`, defaults to `1280`):
+ The width of the generated video.
+ segment_frame_length (`int`, defaults to `77`):
+ The number of frames in each generated video segment. The total number of generated frames equals the
+ number of frames in `pose_video`; the video is generated segment by segment until that length is reached.
+ In general, should be 4N + 1, where N is a non-negative integer.
+ num_inference_steps (`int`, defaults to `20`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `1.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality. By default, CFG is not used in Wan
+ Animate inference.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+
+ Examples:
+
+ Returns:
+ [`~WanPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is the generated video frames.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ image,
+ pose_video,
+ face_video,
+ background_video,
+ mask_video,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ image_embeds,
+ callback_on_step_end_tensor_inputs,
+ mode,
+ prev_segment_conditioning_frames,
+ )
+
+ if segment_frame_length % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`segment_frame_length - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the"
+ f" nearest number."
+ )
+ segment_frame_length = (
+ segment_frame_length // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ )
+ segment_frame_length = max(segment_frame_length, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # As we generate in segments of `segment_frame_length`, set the target frame length to the smallest value of
+ # the form `prev_segment_conditioning_frames + k * effective_segment_length` that covers all of `pose_video`.
+ cond_video_frames = len(pose_video)
+ effective_segment_length = segment_frame_length - prev_segment_conditioning_frames
+ last_segment_frames = (cond_video_frames - prev_segment_conditioning_frames) % effective_segment_length
+ if last_segment_frames == 0:
+ num_padding_frames = 0
+ else:
+ num_padding_frames = effective_segment_length - last_segment_frames
+ num_target_frames = cond_video_frames + num_padding_frames
+ num_segments = num_target_frames // effective_segment_length
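+ # e.g. 161 pose frames with segment_frame_length=77 and prev_segment_conditioning_frames=1 give
+ # effective_segment_length=76, last_segment_frames=8, num_padding_frames=68, num_target_frames=229,
+ # and num_segments=3.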
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Preprocess and encode the reference (character) image
+ image_height, image_width = self.video_processor.get_default_height_width(image)
+ if image_height != height or image_width != width:
+ logger.warning(f"Reshaping reference image from ({image_width}, {image_height}) to ({width}, {height})")
+ image_pixels = self.vae_image_processor.preprocess(image, height=height, width=width, resize_mode="fill").to(
+ device, dtype=torch.float32
+ )
+
+ # Get CLIP features from the reference image
+ if image_embeds is None:
+ image_embeds = self.encode_image(image, device)
+ image_embeds = image_embeds.repeat(batch_size * num_videos_per_prompt, 1, 1)
+ image_embeds = image_embeds.to(transformer_dtype)
+
+ # 5. Encode conditioning videos (pose, face)
+ pose_video = self.pad_video_frames(pose_video, num_target_frames)
+ face_video = self.pad_video_frames(face_video, num_target_frames)
+
+ # TODO: also support np.ndarray input (e.g. from decord like the original implementation?)
+ pose_video_width, pose_video_height = pose_video[0].size
+ if pose_video_height != height or pose_video_width != width:
+ logger.warning(
+ f"Reshaping pose video from ({pose_video_width}, {pose_video_height}) to ({width}, {height})"
+ )
+ pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+
+ face_video_width, face_video_height = face_video[0].size
+ expected_face_size = self.transformer.config.motion_encoder_size
+ if face_video_width != expected_face_size or face_video_height != expected_face_size:
+ logger.warning(
+ f"Reshaping face video from ({face_video_width}, {face_video_height}) to ({expected_face_size},"
+ f" {expected_face_size})"
+ )
+ face_video = self.video_processor.preprocess_video(
+ face_video, height=expected_face_size, width=expected_face_size
+ ).to(device, dtype=torch.float32)
+
+ if mode == "replace":
+ background_video = self.pad_video_frames(background_video, num_target_frames)
+ mask_video = self.pad_video_frames(mask_video, num_target_frames)
+
+ background_video = self.video_processor.preprocess_video(background_video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+ mask_video = self.video_processor_for_mask.preprocess_video(mask_video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+
+ # 6. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 7. Prepare latent variables which stay constant for all inference segments
+ num_channels_latents = self.vae.config.z_dim
+
+ # Get VAE-encoded latents of the reference (character) image
+ reference_image_latents = self.prepare_reference_image_latents(
+ image_pixels, batch_size * num_videos_per_prompt, generator=generator, device=device
+ )
+
+ # 8. Loop over video inference segments
+ start = 0
+ end = segment_frame_length # Data space frames, not latent frames
+ all_out_frames = []
+ out_frames = None
+
+ for _ in range(num_segments):
+ assert start + prev_segment_conditioning_frames < cond_video_frames
+
+ # Sample noisy latents from prior for the current inference segment
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents=num_channels_latents,
+ height=height,
+ width=width,
+ num_frames=segment_frame_length,
+ dtype=torch.float32,
+ device=device,
+ generator=generator,
+ latents=latents if start == 0 else None, # Only use pre-calculated latents for first segment
+ )
+
+ pose_video_segment = pose_video[:, :, start:end]
+ face_video_segment = face_video[:, :, start:end]
+
+ face_video_segment = face_video_segment.expand(batch_size * num_videos_per_prompt, -1, -1, -1, -1)
+ face_video_segment = face_video_segment.to(dtype=transformer_dtype)
+
+ if start > 0:
+ prev_segment_cond_video = out_frames[:, :, -prev_segment_conditioning_frames:].clone().detach()
+ else:
+ prev_segment_cond_video = None
+
+ if mode == "replace":
+ background_video_segment = background_video[:, :, start:end]
+ mask_video_segment = mask_video[:, :, start:end]
+
+ background_video_segment = background_video_segment.expand(
+ batch_size * num_videos_per_prompt, -1, -1, -1, -1
+ )
+ mask_video_segment = mask_video_segment.expand(batch_size * num_videos_per_prompt, -1, -1, -1, -1)
+ else:
+ background_video_segment = None
+ mask_video_segment = None
+
+ pose_latents = self.prepare_pose_latents(
+ pose_video_segment, batch_size * num_videos_per_prompt, generator=generator, device=device
+ )
+ pose_latents = pose_latents.to(dtype=transformer_dtype)
+
+ prev_segment_cond_latents = self.prepare_prev_segment_cond_latents(
+ prev_segment_cond_video,
+ background_video=background_video_segment,
+ mask_video=mask_video_segment,
+ batch_size=batch_size * num_videos_per_prompt,
+ segment_frame_length=segment_frame_length,
+ start_frame=start,
+ height=height,
+ width=width,
+ prev_segment_cond_frames=prev_segment_conditioning_frames,
+ task=mode,
+ generator=generator,
+ device=device,
+ )
+
+ # Concatenate the reference latents in the frame dimension
+ reference_latents = torch.cat([reference_image_latents, prev_segment_cond_latents], dim=2)
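+ # reference_latents shape: [B, 4 + z_dim, 1 + T_lat, H_lat, W_lat] (mask + latent channels; reference frame + segment frames)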
+
+ # 8.1 Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ # Concatenate the reference image + prev segment conditioning in the channel dim
+ latent_model_input = torch.cat([latents, reference_latents], dim=1).to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ pose_hidden_states=pose_latents,
+ face_pixel_values=face_video_segment,
+ motion_encode_batch_size=motion_encode_batch_size,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ # Blank out face for unconditional guidance (set all pixels to -1)
+ face_pixel_values_uncond = face_video_segment * 0 - 1
+ with self.transformer.cache_context("uncond"):
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ pose_hidden_states=pose_latents,
+ face_pixel_values=face_pixel_values_uncond,
+ motion_encode_batch_size=motion_encode_batch_size,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ latents = latents.to(self.vae.dtype)
+ # Destandardize latents in preparation for Wan VAE decoding
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
+ 1, self.vae.config.z_dim, 1, 1, 1
+ ).to(latents.device, latents.dtype)
+ latents = latents / latents_recip_std + latents_mean
+ # Skip the first latent frame (used for conditioning)
+ out_frames = self.vae.decode(latents[:, :, 1:], return_dict=False)[0]
+
+ if start > 0:
+ out_frames = out_frames[:, :, prev_segment_conditioning_frames:]
+ all_out_frames.append(out_frames)
+
+ start += effective_segment_length
+ end += effective_segment_length
+
+ # Reset scheduler timesteps / state for next denoising loop
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ self._current_timestep = None
+ assert start + prev_segment_conditioning_frames >= cond_video_frames
+
+ if not output_type == "latent":
+ video = torch.cat(all_out_frames, dim=2)[:, :, :cond_video_frames]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return WanPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py
new file mode 100644
index 000000000000..6f78cec07442
--- /dev/null
+++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py
@@ -0,0 +1,1054 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import regex as re
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from transformers import AutoTokenizer, UMT5EncoderModel, Wav2Vec2ForCTC, Wav2Vec2Processor
+
+from ...audio_processor import PipelineAudioInput
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import WanLoraLoaderMixin
+from ...models import AutoencoderKLWan, WanS2VTransformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, load_video, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import WanPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import numpy as np, math, requests
+ >>> import torch
+ >>> from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline
+ >>> from diffusers.utils import export_to_video, load_audio, export_to_merged_video_audio
+ >>> from transformers import Wav2Vec2ForCTC
+ >>> from PIL import Image
+ >>> from io import BytesIO
+
+ >>> model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers"
+ >>> audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32)
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = WanSpeechToVideoPipeline.from_pretrained(
+ ... model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> headers = {"User-Agent": "Mozilla/5.0"}
+ >>> url = "https://upload.wikimedia.org/wikipedia/commons/4/46/Albert_Einstein_sticks_his_tongue.jpg"
+ >>> resp = requests.get(url, headers=headers, timeout=30)
+ >>> image = Image.open(BytesIO(resp.content))
+
+ >>> audio, sampling_rate = load_audio(
+ ... "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/Five%20Hundred%20Miles.MP3"
+ ... )
+ >>> # pose_video_path_or_url = "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4"
+
+
+ >>> def get_size_less_than_area(height, width, target_area=1024 * 704, divisor=64):
+ ... if height * width <= target_area:
+ ... # If the original image area is already less than or equal to the target,
+ ... # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target.
+ ... max_upper_area = target_area
+ ... min_scale = 0.1
+ ... max_scale = 1.0
+ ... else:
+ ... # Resize to fit within the target area and then pad to multiples of `divisor`
+ ... max_upper_area = target_area # Maximum allowed total pixel count after padding
+ ... d = divisor - 1
+ ... b = d * (height + width)
+ ... a = height * width
+ ... c = d**2 - max_upper_area
+
+ ... # Calculate scale boundaries using quadratic equation
+ ... min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (2 * a) # Scale when maximum padding is applied
+ ... max_scale = math.sqrt(max_upper_area / (height * width)) # Scale without any padding
+
+ ... # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area
+ ... # Use binary search-like iteration to find this scale
+ ... find_it = False
+ ... for i in range(100):
+ ... scale = max_scale - (max_scale - min_scale) * i / 100
+ ... new_height, new_width = int(height * scale), int(width * scale)
+
+ ... # Pad to make dimensions divisible by 64
+ ... pad_height = (64 - new_height % 64) % 64
+ ... pad_width = (64 - new_width % 64) % 64
+ ... pad_top = pad_height // 2
+ ... pad_bottom = pad_height - pad_top
+ ... pad_left = pad_width // 2
+ ... pad_right = pad_width - pad_left
+
+ ... padded_height, padded_width = new_height + pad_height, new_width + pad_width
+
+ ... if padded_height * padded_width <= max_upper_area:
+ ... find_it = True
+ ... break
+
+ ... if find_it:
+ ... return padded_height, padded_width
+ ... else:
+ ... # Fallback: calculate target dimensions based on aspect ratio and divisor alignment
+ ... aspect_ratio = width / height
+ ... target_width = int((target_area * aspect_ratio) ** 0.5 // divisor * divisor)
+ ... target_height = int((target_area / aspect_ratio) ** 0.5 // divisor * divisor)
+
+ ... # Ensure the result is not larger than the original resolution
+ ... if target_width >= width or target_height >= height:
+ ... target_width = int(width // divisor * divisor)
+ ... target_height = int(height // divisor * divisor)
+
+ ... return target_height, target_width
+
+
+ >>> height, width = get_size_less_than_area(image.height, image.width, target_area=480 * 832)
+
+ >>> prompt = "Einstein singing a song."
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... image=image,
+ ... audio=audio,
+ ... sampling_rate=sampling_rate,
+ ... height=height,
+ ... width=width,
+ ... num_frames_per_chunk=80,
+ ... # pose_video_path_or_url=pose_video_path_or_url,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=16)
+
+ >>> # Lastly, we need to merge the video and audio into a new video, with the duration set to
+ >>> # the shorter of the two, overwriting the original video file.
+
+ >>> export_to_merged_video_audio("output.mp4", "audio.mp3")
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None):
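+ """
+ Return `num_sample` frame indices that resample a `total_frames`-long video recorded at `original_fps` to
+ `target_fps`, starting at `fixed_start` if provided (otherwise at a random valid start frame).
+ """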
+ required_duration = num_sample / target_fps
+ required_origin_frames = int(np.ceil(required_duration * original_fps))
+ if required_duration > total_frames / original_fps:
+ raise ValueError("required_duration must be less than video length")
+
+ if fixed_start is not None and fixed_start >= 0:
+ start_frame = fixed_start
+ else:
+ max_start = total_frames - required_origin_frames
+ if max_start < 0:
+ raise ValueError("video length is too short")
+ start_frame = np.random.randint(0, max_start + 1)
+ start_time = start_frame / original_fps
+
+ end_time = start_time + required_duration
+ time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
+
+ frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
+ frame_indices = np.clip(frame_indices, 0, total_frames - 1)
+ return frame_indices
+
+
+def linear_interpolation(features, input_fps, output_fps, output_len=None):
+    """
+    Linearly resample audio features along the temporal axis from `input_fps` to `output_fps`.
+
+    Args:
+        features: Audio features of shape `[1, T, 512]` extracted at `input_fps` frames per second.
+        input_fps: Frame rate of the audio features, `f_a`.
+        output_fps: Frame rate of the video, `f_m`.
+        output_len: Unused; the output length is recomputed internally as `int(T / f_a * f_m)`.
+    """
+ features = features.transpose(1, 2) # [1, 512, T]
+ seq_len = features.shape[2] / float(input_fps) # T/f_a
+ output_len = int(seq_len * output_fps) # f_m*T/f_a
+ output_features = F.interpolate(
+ features, size=output_len, align_corners=True, mode="linear"
+ ) # [1, 512, output_len]
+ return output_features.transpose(1, 2) # [1, output_len, 512]
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class WanSpeechToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Pipeline for prompt-image-audio-to-video generation using Wan2.2-S2V.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+        transformer ([`WanS2VTransformer3DModel`]):
+            Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ audio_encoder ([`Wav2Vec2ForCTC`]):
+ Audio Encoder to process audio inputs.
+ audio_processor ([`Wav2Vec2Processor`]):
+ Audio Processor to preprocess audio inputs.
+ """
+
+ model_cpu_offload_seq = "text_encoder->audio_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ vae: AutoencoderKLWan,
+ scheduler: UniPCMultistepScheduler,
+ transformer: WanS2VTransformer3DModel,
+ audio_encoder: Wav2Vec2ForCTC,
+ audio_processor: Wav2Vec2Processor,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ audio_encoder=audio_encoder,
+ audio_processor=audio_processor,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, resample="bilinear")
+ self.audio_processor = audio_processor
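+        # Number of past pixel frames carried between chunks as motion context, and whether the
+        # motion context of the first chunk (initialized with zeros) is dropped.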
+ self.motion_frames = 73
+ self.drop_first_motion = True
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_audio(
+ self,
+ audio: PipelineAudioInput,
+ sampling_rate: int,
+ num_frames: int,
+ fps: int = 16,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+ video_rate = 30
+ audio_sample_m = 0
+
+ input_values = self.audio_processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_values
+
+        # Run the audio encoder and keep the hidden states from every layer
+        res = self.audio_encoder(input_values.to(device), output_hidden_states=True)
+        feat = torch.cat(res.hidden_states)  # stack per-layer features: [num_layers, T, hidden_dim]
+
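+        # Resample the stacked audio features from the audio-feature rate (50 per second) to the
+        # internal 30 fps video rate before aligning them with video frames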
+ feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
+
+ audio_embed = feat.to(torch.float32) # Encoding for the motion
+
+ num_layers, audio_frame_num, audio_dim = audio_embed.shape
+
+        return_all_layers = num_layers > 1
+
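+        # Determine how many chunks of `num_frames` are needed to cover the full audio clip at the
+        # requested video fps, padding the feature sequence so the last chunk is complete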
+ scale = video_rate / fps
+
+ num_repeat = int(audio_frame_num / (num_frames * scale)) + 1
+
+ bucket_num = num_repeat * num_frames
+        pad_audio_num = math.ceil(num_repeat * num_frames / fps * video_rate) - audio_frame_num
+ batch_idx = get_sample_indices(
+ original_fps=video_rate,
+            total_frames=audio_frame_num + pad_audio_num,
+ target_fps=fps,
+ num_sample=bucket_num,
+ fixed_start=0,
+ )
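+        # For every sampled video frame, gather a small window of audio features (zeros beyond the
+        # end of the audio)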
+ batch_audio_eb = []
+ audio_sample_stride = int(video_rate / fps)
+ for bi in batch_idx:
+ if bi < audio_frame_num:
+ chosen_idx = list(
+ range(
+ bi - audio_sample_m * audio_sample_stride,
+ bi + (audio_sample_m + 1) * audio_sample_stride,
+ audio_sample_stride,
+ )
+ )
+ chosen_idx = [0 if c < 0 else c for c in chosen_idx]
+ chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]
+
+ if return_all_layers:
+ frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
+ else:
+ frame_audio_embed = audio_embed[0][chosen_idx].flatten()
+ else:
+ frame_audio_embed = (
+ torch.zeros([audio_dim * (2 * audio_sample_m + 1)], device=audio_embed.device)
+ if not return_all_layers
+ else torch.zeros([num_layers, audio_dim * (2 * audio_sample_m + 1)], device=audio_embed.device)
+ )
+ batch_audio_eb.append(frame_audio_embed)
+ audio_embed_bucket = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
+
+ audio_embed_bucket = audio_embed_bucket.to(device)
+ audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
+ if len(audio_embed_bucket.shape) == 3:
+ audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
+ elif len(audio_embed_bucket.shape) == 4:
+ audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
+ return audio_embed_bucket, num_repeat
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ audio=None,
+ audio_embeds=None,
+ ):
+ if image is not None and image_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ if image is None and image_embeds is None:
+ raise ValueError(
+                "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+ )
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+ if audio is not None and audio_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `audio`: {audio} and `audio_embeds`: {audio_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif audio is None and audio_embeds is None:
+ raise ValueError(
+ "Provide either `audio` or `audio_embeds`. Cannot leave both `audio` and `audio_embeds` undefined."
+ )
+        elif audio is not None and not isinstance(audio, np.ndarray):
+ raise ValueError(f"`audio` has to be of type `np.ndarray` but is {type(audio)}")
+
+ def prepare_latents(
+ self,
+ image: PipelineImageInput,
+ batch_size: int,
+ latent_motion_frames: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames_per_chunk: int = 80,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ pose_video: Optional[List[Image.Image]] = None,
+ init_first_frame: bool = False,
+ num_chunks: int = 1,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[torch.Tensor]]]:
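+        # Number of new latent frames produced per chunk: the chunk's pixel frames plus the motion
+        # context pass through the VAE's temporal downsampling, minus the latent motion frames that
+        # are supplied as context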
+ num_latent_frames = (
+ num_frames_per_chunk + 3 + self.motion_frames
+ ) // self.vae_scale_factor_temporal - latent_motion_frames
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ if image is not None:
+ image = image.unsqueeze(2) # [batch_size, channels, 1, height, width]
+
+ video_condition = image.to(device=device, dtype=self.vae.dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ if isinstance(generator, list):
+ latent_condition = [
+ retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+ ]
+ latent_condition = torch.cat(latent_condition)
+ else:
+ latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
+ latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+
+ latent_condition = latent_condition.to(dtype)
+ latent_condition = (latent_condition - latents_mean) * latents_std
+
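+            # Initialize the pixel-space motion buffer with zeros; when `init_first_frame` is set,
+            # seed its last frames with the reference image so the first frame matches the input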
+ motion_pixels = torch.zeros([1, 3, self.motion_frames, height, width], dtype=self.vae.dtype, device=device)
+ # Get pose condition input if needed
+ pose_condition = self.load_pose_condition(
+ pose_video, num_chunks, num_frames_per_chunk, height, width, latents_mean, latents_std
+ )
+ # Encode motion latents
+ videos_last_pixels = motion_pixels.detach()
+ if init_first_frame:
+ self.drop_first_motion = False
+ motion_pixels[:, :, -6:] = video_condition
+ motion_latents = retrieve_latents(self.vae.encode(motion_pixels), sample_mode="argmax")
+ motion_latents = (motion_latents - latents_mean) * latents_std
+
+ return latents, latent_condition, videos_last_pixels, motion_latents, pose_condition
+ else:
+ return latents
+
+ def load_pose_condition(
+ self, pose_video, num_chunks, num_frames_per_chunk, height, width, latents_mean, latents_std
+ ):
+ device = self._execution_device
+ dtype = self.vae.dtype
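+        # Pad the pose video with -1 frames so it splits evenly into `num_chunks`; without a pose
+        # video, use a single all -1 chunk as a neutral placeholder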
+ if pose_video is not None:
+ padding_frame_num = num_chunks * num_frames_per_chunk - pose_video.shape[2]
+ pose_video = pose_video.to(dtype=dtype, device=device)
+ pose_video = torch.cat(
+ [
+ pose_video,
+ -torch.ones([1, 3, padding_frame_num, height, width], dtype=dtype, device=device),
+ ],
+ dim=2,
+ )
+
+ pose_video = torch.chunk(pose_video, num_chunks, dim=2)
+ else:
+ pose_video = [-torch.ones([1, 3, num_frames_per_chunk, height, width], dtype=dtype, device=device)]
+
+ # Vectorized processing: concatenate all chunks along batch dimension
+ all_poses = torch.cat(
+ [torch.cat([cond[:, :, 0:1], cond], dim=2) for cond in pose_video], dim=0
+ ) # Shape: [num_chunks, 3, num_frames_per_chunk+1, height, width]
+
+ pose_condition = retrieve_latents(self.vae.encode(all_poses), sample_mode="argmax")[:, :, 1:]
+ pose_condition = (pose_condition - latents_mean) * latents_std
+
+ return pose_condition
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ audio: PipelineAudioInput,
+ sampling_rate: int,
+ prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+ pose_video_path_or_url: Optional[str] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames_per_chunk: int = 80,
+ num_inference_steps: int = 40,
+ guidance_scale: float = 4.5,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ audio_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ init_first_frame: bool = False,
+ sampling_fps: int = 16,
+ num_chunks: Optional[int] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ audio (`PipelineAudioInput`):
+                The audio input to condition the generation on. Must be an audio waveform, a list of waveforms, or a
+                `torch.Tensor`.
+ sampling_rate (`int`):
+ The sampling rate of the audio input.
+ prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ pose_video_path_or_url (`str` or `List[str]`, *optional*):
+ The path or URL to the pose video to condition the generation on.
+ height (`int`, defaults to `480`):
+ The height of the generated video.
+ width (`int`, defaults to `832`):
+ The width of the generated video.
+ num_frames_per_chunk (`int`, defaults to `80`):
+ The number of frames in each chunk of the generated video. `num_frames_per_chunk` should be a multiple
+ of 4.
+            num_inference_steps (`int`, defaults to `40`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+            guidance_scale (`float`, defaults to `4.5`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ audio_embeds (`torch.Tensor`, *optional*):
+ Pre-generated audio embeddings. Can be used to easily tweak audio inputs (weighting). If not provided,
+ audio embeddings are generated from the `audio` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                each denoising step during inference with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+ init_first_frame (`bool`, *optional*, defaults to False):
+ Whether to use the reference image as the first frame (i.e., standard image-to-video generation).
+ sampling_fps (`int`, *optional*, defaults to 16):
+ The frame rate (in frames per second) at which the generated video will be sampled.
+            num_chunks (`int`, *optional*, defaults to `None`):
+                The number of chunks to generate. If not provided, it is derived from the audio input so that the
+                whole audio clip is covered. For example, if the audio spans 4 chunks, setting `num_chunks=1` generates
+                only the first chunk instead of the whole video.
+ Examples:
+
+ Returns:
+            [`~WanPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
+                the first element is a list with the generated video frames.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ image_embeds,
+ callback_on_step_end_tensor_inputs,
+ audio,
+ audio_embeds,
+ )
+
+ if num_frames_per_chunk % self.vae_scale_factor_temporal != 0:
+ num_frames_per_chunk = (
+ num_frames_per_chunk // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal
+ )
+ logger.warning(
+                f"`num_frames_per_chunk` has to be divisible by {self.vae_scale_factor_temporal}. Rounding down to the nearest multiple: {num_frames_per_chunk}"
+ )
+ num_frames_per_chunk = max(num_frames_per_chunk, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ if audio_embeds is None:
+ audio_embeds, num_chunks_audio = self.encode_audio(
+ audio, sampling_rate, num_frames_per_chunk, sampling_fps, device
+ )
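+            # Never request more chunks than the audio actually covers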
+ if num_chunks is None or num_chunks > num_chunks_audio:
+ num_chunks = num_chunks_audio
+ audio_embeds = audio_embeds.to(transformer_dtype)
+
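+        # Length of the motion context expressed in latent frames (after VAE temporal downsampling)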
+ latent_motion_frames = (self.motion_frames + 3) // self.vae_scale_factor_temporal
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.vae.config.z_dim
+ image = self.video_processor.preprocess(
+ image, height=height, width=width, resize_mode="resize_min_center_crop"
+ ).to(device, dtype=torch.float32)
+
+ pose_video = None
+ if pose_video_path_or_url is not None:
+ pose_video = load_video(
+ pose_video_path_or_url,
+ n_frames=num_frames_per_chunk * num_chunks,
+ target_fps=sampling_fps,
+ reverse=True,
+ )
+ pose_video = self.video_processor.preprocess_video(
+ pose_video, height=height, width=width, resize_mode="resize_min_center_crop"
+ ).to(device, dtype=torch.float32)
+
+ video_chunks = []
+ for r in range(num_chunks):
+ latents_outputs = self.prepare_latents(
+ image if r == 0 else None,
+ batch_size * num_videos_per_prompt,
+ latent_motion_frames,
+ num_channels_latents,
+ height,
+ width,
+ num_frames_per_chunk,
+ torch.float32,
+ device,
+ generator,
+ latents if r == 0 else None,
+ pose_video,
+ init_first_frame,
+ num_chunks,
+ )
+
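+            # Only the first chunk prepares the image condition, motion buffers, and pose condition;
+            # later chunks reuse them and only draw fresh noise latents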
+ if r == 0:
+ latents, condition, videos_last_pixels, motion_latents, pose_condition = latents_outputs
+ else:
+ latents = latents_outputs
+
+ with torch.no_grad():
+ left_idx = r * num_frames_per_chunk
+ right_idx = r * num_frames_per_chunk + num_frames_per_chunk
+ pose_latents = pose_condition[r] if pose_video is not None else pose_condition[0] * 0
+ pose_latents = pose_latents.to(dtype=transformer_dtype, device=device)
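+                # Select the audio embeddings that correspond to this chunk's frame range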
+ audio_embeds_input = audio_embeds[..., left_idx:right_idx]
+ motion_latents_input = motion_latents.to(transformer_dtype).clone()
+
+ # 4. Prepare timesteps by resetting scheduler in each chunk
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ latent_model_input = latents.to(transformer_dtype)
+ condition = condition.to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ motion_latents=motion_latents_input,
+ image_latents=condition,
+ pose_latents=pose_latents,
+ audio_embeds=audio_embeds_input,
+ motion_frames=[self.motion_frames, latent_motion_frames],
+ drop_motion_frames=self.drop_first_motion and r == 0,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ motion_latents=motion_latents_input,
+ image_latents=condition,
+ pose_latents=pose_latents,
+ audio_embeds=0.0 * audio_embeds_input,
+ motion_frames=[self.motion_frames, latent_motion_frames],
+ drop_motion_frames=self.drop_first_motion and r == 0,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
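+            # Prepend the context latents (motion latents, or the image condition for the first
+            # chunk) before decoding, then keep only this chunk's frames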
+ if not (self.drop_first_motion and r == 0):
+ decode_latents = torch.cat([motion_latents, latents], dim=2)
+ else:
+ decode_latents = torch.cat([condition, latents], dim=2)
+
+ decode_latents = decode_latents.to(self.vae.dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(decode_latents.device, decode_latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ decode_latents.device, decode_latents.dtype
+ )
+ decode_latents = decode_latents / latents_std + latents_mean
+ video = self.vae.decode(decode_latents, return_dict=False)[0]
+ video = video[:, :, -(num_frames_per_chunk):]
+
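+            # When the first chunk was generated without motion context, drop its first 3 decoded frames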
+ if self.drop_first_motion and r == 0:
+ video = video[:, :, 3:]
+
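+            # Roll the pixel-space motion buffer: discard the oldest frames and append the newest
+            # frames from this chunk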
+ num_overlap_frames = min(self.motion_frames, video.shape[2])
+ videos_last_pixels = torch.cat(
+ [videos_last_pixels[:, :, num_overlap_frames:], video[:, :, -num_overlap_frames:]], dim=2
+ )
+
+ # Update motion_latents for next iteration
+ motion_latents = retrieve_latents(self.vae.encode(videos_last_pixels), sample_mode="argmax")
+ motion_latents = (motion_latents - latents_mean) * latents_std
+
+ video_chunks.append(video)
+
+ video_chunks = torch.cat(video_chunks, dim=2)
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ video = self.video_processor.postprocess_video(video_chunks, output_type=output_type)
+ else:
+ # TODO
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return WanPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
index 63e557a98fbe..351ae2e70563 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -758,11 +758,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `5.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
guidance_scale_2 (`float`, *optional*, defaults to `None`):
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
diff --git a/src/diffusers/schedulers/scheduling_amused.py b/src/diffusers/schedulers/scheduling_amused.py
index 238b8d869171..a0b8fbc862b0 100644
--- a/src/diffusers/schedulers/scheduling_amused.py
+++ b/src/diffusers/schedulers/scheduling_amused.py
@@ -1,6 +1,6 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import torch
@@ -9,13 +9,48 @@
from .scheduling_utils import SchedulerMixin
-def gumbel_noise(t, generator=None):
+def gumbel_noise(t: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+ """
+ Generate Gumbel noise for sampling.
+
+ Args:
+ t (`torch.Tensor`):
+ Input tensor to match the shape and dtype of the output noise.
+ generator (`torch.Generator`, *optional*):
+ A random number generator for reproducible sampling.
+
+ Returns:
+ `torch.Tensor`:
+ Gumbel-distributed noise with the same shape, dtype, and device as the input tensor.
+ """
device = generator.device if generator is not None else t.device
noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
-def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
+def mask_by_random_topk(
+ mask_len: torch.Tensor,
+ probs: torch.Tensor,
+ temperature: float = 1.0,
+ generator: Optional[torch.Generator] = None,
+) -> torch.Tensor:
+ """
+ Mask tokens by selecting the top-k lowest confidence scores with temperature-based randomness.
+
+ Args:
+ mask_len (`torch.Tensor`):
+ Number of tokens to mask per sample in the batch.
+ probs (`torch.Tensor`):
+ Probability scores for each token.
+ temperature (`float`, *optional*, defaults to 1.0):
+ Temperature parameter for controlling randomness in the masking process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator for reproducible sampling.
+
+ Returns:
+ `torch.Tensor`:
+ Boolean mask indicating which tokens should be masked.
+ """
confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
sorted_confidence = torch.sort(confidence, dim=-1).values
cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
@@ -29,28 +64,46 @@ class AmusedSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output.
Args:
- prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
- Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
- denoising loop.
- pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
- The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
- `pred_original_sample` can be used to preview progress or for guidance.
+ prev_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`):
+ Computed sample `(x_{t-1})` of previous timestep with token IDs. `prev_sample` should be used as next model
+ input in the denoising loop.
+ pred_original_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`, *optional*):
+ The predicted fully denoised sample `(x_{0})` with token IDs based on the model output from the current
+ timestep. `pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.Tensor
- pred_original_sample: torch.Tensor = None
+ pred_original_sample: Optional[torch.Tensor] = None
class AmusedScheduler(SchedulerMixin, ConfigMixin):
+ """
+ A scheduler for masked token generation as used in [`AmusedPipeline`].
+
+ This scheduler iteratively unmasks tokens based on their confidence scores, following either a cosine or linear
+ schedule. Unlike traditional diffusion schedulers that work with continuous pixel values, this scheduler operates
+ on discrete token IDs, making it suitable for autoregressive and non-autoregressive masked token generation models.
+
+ This scheduler inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
+ generic methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ mask_token_id (`int`):
+ The token ID used to represent masked tokens in the sequence.
+ masking_schedule (`Literal["cosine", "linear"]`, *optional*, defaults to `"cosine"`):
+ The schedule type for determining the mask ratio at each timestep. Can be either `"cosine"` or `"linear"`.
+ """
+
order = 1
- temperatures: torch.Tensor
+ temperatures: Optional[torch.Tensor]
+ timesteps: Optional[torch.Tensor]
@register_to_config
def __init__(
self,
mask_token_id: int,
- masking_schedule: str = "cosine",
+ masking_schedule: Literal["cosine", "linear"] = "cosine",
):
self.temperatures = None
self.timesteps = None
@@ -58,9 +111,23 @@ def __init__(
def set_timesteps(
self,
num_inference_steps: int,
- temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
- device: Union[str, torch.device] = None,
- ):
+ temperature: Union[float, Tuple[float, float], List[float]] = (2, 0),
+ device: Optional[Union[str, torch.device]] = None,
+ ) -> None:
+ """
+ Set the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ temperature (`Union[float, Tuple[float, float], List[float]]`, *optional*, defaults to `(2, 0)`):
+ Temperature parameter(s) for controlling the randomness of sampling. If a tuple or list is provided,
+ temperatures will be linearly interpolated between the first and second values across all timesteps. If
+ a single value is provided, temperatures will be linearly interpolated from that value to 0.01.
+ device (`str` or `torch.device`, *optional*):
+                The device to which the timesteps and temperatures should be moved. If `None`, the timesteps are not
+                moved.
+ """
self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
if isinstance(temperature, (tuple, list)):
@@ -71,12 +138,38 @@ def set_timesteps(
def step(
self,
model_output: torch.Tensor,
- timestep: torch.long,
+ timestep: int,
sample: torch.LongTensor,
- starting_mask_ratio: int = 1,
+ starting_mask_ratio: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
- ) -> Union[AmusedSchedulerOutput, Tuple]:
+ ) -> Union[AmusedSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Predict the sample at the previous timestep by masking tokens based on confidence scores.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model. Typically of shape `(batch_size, num_tokens,
+ codebook_size)` or `(batch_size, codebook_size, height, width)` for 2D inputs.
+ timestep (`int`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.LongTensor`):
+ A current instance of a sample created by the diffusion process. Contains token IDs, with masked
+ positions indicated by `mask_token_id`.
+ starting_mask_ratio (`float`, *optional*, defaults to 1.0):
+ A multiplier applied to the mask ratio schedule. Values less than 1.0 will result in fewer tokens being
+ masked at each step.
+ generator (`torch.Generator`, *optional*):
+ A random number generator for reproducible sampling.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return an [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or a plain tuple.
+
+ Returns:
+ [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] is returned,
+ otherwise a tuple is returned where the first element is the sample tensor (`prev_sample`) and the
+ second element is the predicted original sample tensor (`pred_original_sample`).
+ """
two_dim_input = sample.ndim == 3 and model_output.ndim == 4
if two_dim_input:
@@ -137,7 +230,27 @@ def step(
return AmusedSchedulerOutput(prev_sample, pred_original_sample)
- def add_noise(self, sample, timesteps, generator=None):
+ def add_noise(
+ self,
+ sample: torch.LongTensor,
+ timesteps: int,
+ generator: Optional[torch.Generator] = None,
+ ) -> torch.LongTensor:
+ """
+ Add noise to a sample by randomly masking tokens according to the masking schedule.
+
+ Args:
+ sample (`torch.LongTensor`):
+ The input sample containing token IDs to be partially masked.
+ timesteps (`int`):
+ The timestep that determines how much masking to apply. Higher timesteps result in more masking.
+ generator (`torch.Generator`, *optional*):
+ A random number generator for reproducible masking.
+
+ Returns:
+ `torch.LongTensor`:
+ The sample with some tokens replaced by `mask_token_id` according to the masking schedule.
+ """
step_idx = (self.timesteps == timesteps).nonzero()
ratio = (step_idx + 1) / len(self.timesteps)
diff --git a/src/diffusers/schedulers/scheduling_consistency_decoder.py b/src/diffusers/schedulers/scheduling_consistency_decoder.py
index d7af018b284a..767fa9157f59 100644
--- a/src/diffusers/schedulers/scheduling_consistency_decoder.py
+++ b/src/diffusers/schedulers/scheduling_consistency_decoder.py
@@ -1,6 +1,6 @@
import math
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Literal, Optional, Tuple, Union
import torch
@@ -12,10 +12,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -23,16 +23,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py
index 5d81d5eb8ac0..386a43db0f9c 100644
--- a/src/diffusers/schedulers/scheduling_consistency_models.py
+++ b/src/diffusers/schedulers/scheduling_consistency_models.py
@@ -121,7 +121,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -287,7 +287,23 @@ def get_scalings_for_boundary_condition(self, sigma):
return c_skip, c_out
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -302,7 +318,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -410,6 +433,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
diff --git a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
index b9567f2c47d5..7b11d704932b 100644
--- a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
@@ -137,7 +137,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -266,6 +266,19 @@ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -537,6 +550,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index 5ee0d084f060..d7fe29a72ac9 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -17,7 +17,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -49,10 +49,10 @@ class DDIMSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,16 +60,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -92,17 +93,17 @@ def alpha_bar_fn(t):
return torch.tensor(betas, dtype=torch.float32)
-def rescale_zero_terminal_snr(betas):
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
-
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -143,9 +144,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Must be one
+ of `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
clip_sample (`bool`, defaults to `True`):
@@ -158,10 +159,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
otherwise it uses the alpha value at step 0.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
- Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`):
+ Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
+ process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen
+ Video](https://huggingface.co/papers/2210.02303) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -169,9 +170,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
- timestep_spacing (`str`, defaults to `"leading"`):
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ timestep_spacing (`Literal["leading", "trailing", "linspace"]`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Must be one of `"leading"`, `"trailing"`, or `"linspace"`. Refer to
+ Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891) for more information.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -187,17 +189,17 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
- timestep_spacing: str = "leading",
+ timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
@@ -250,7 +252,25 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
"""
return sample
- def _get_variance(self, timestep, prev_timestep):
+ def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
+ """
+ Computes the variance of the noise added at a given diffusion step.
+
+ For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM
+ literature:
+ var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+ where alpha_prod and beta_prod are cumulative products of alphas and betas, respectively.
+
+ Args:
+ timestep (`int`):
+ The current timestep in the diffusion process.
+ prev_timestep (`int`):
+ The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.
+
+ Returns:
+ `torch.Tensor`:
+ The variance for the current timestep.
+ """
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
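The variance expression quoted in the new `_get_variance` docstring can be checked in isolation. A minimal sketch, assuming a linear beta schedule and `final_alpha_cumprod = 1.0` (i.e. `set_alpha_to_one=True`); this is an illustration, not the scheduler code:

```python
import torch

betas = torch.linspace(0.0001, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def ddim_variance(timestep: int, prev_timestep: int) -> torch.Tensor:
    # var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    alpha_prod_t = alphas_cumprod[timestep]
    alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else torch.tensor(1.0)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    return (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

print(ddim_variance(timestep=999, prev_timestep=979))
```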
@@ -263,6 +283,8 @@ def _get_variance(self, timestep, prev_timestep):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
@@ -270,6 +292,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -294,13 +324,18 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
return sample
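The dynamic-thresholding passage above is easier to follow as a standalone sketch. The percentile/clamp/rescale steps mirror the quoted description; the tensor shape and the `ratio`/`max_value` names are assumptions for illustration, not the library implementation:

```python
import torch

def dynamic_threshold(sample: torch.Tensor, ratio: float = 0.995, max_value: float = 1.0) -> torch.Tensor:
    batch_size = sample.shape[0]
    flat = sample.reshape(batch_size, -1).abs()
    s = torch.quantile(flat, ratio, dim=1)               # per-sample percentile of |x0|
    s = s.clamp(min=1.0, max=max_value)                  # only threshold when s > 1
    s = s.view(batch_size, *([1] * (sample.ndim - 1)))   # broadcast over remaining dims
    return sample.clamp(-s, s) / s                       # threshold to [-s, s], then rescale

x0_pred = torch.randn(2, 3, 64, 64) * 2.0
print(dynamic_threshold(x0_pred).abs().max())
```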
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`Union[str, torch.device]`, *optional*):
+ The device to which the timesteps should be moved.
+
+ Raises:
+ ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`.
"""
if num_inference_steps > self.config.num_train_timesteps:
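For context on `set_timesteps`, the `"leading"` spacing mentioned in the config docs can be reproduced with a few lines of NumPy. This is a rough sketch assuming 1000 training steps and 50 inference steps; the real method additionally handles `"trailing"`, `"linspace"`, and `steps_offset`:

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 50
step_ratio = num_train_timesteps // num_inference_steps

# "leading" spacing: evenly spaced from 0, then reversed so sampling starts at high noise.
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
print(timesteps[:5])  # e.g. [980 960 940 920 900]
```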
@@ -346,7 +381,7 @@ def step(
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
- generator=None,
+ generator: Optional[torch.Generator] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
@@ -357,20 +392,21 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
- timestep (`float`):
+ timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- eta (`float`):
- The weight of noise for added noise in diffusion step.
- use_clipped_model_output (`bool`, defaults to `False`):
+ eta (`float`, *optional*, defaults to 0.0):
+ The weight of the noise added in the diffusion step. A value of 0 corresponds to DDIM (deterministic)
+ and 1 corresponds to DDPM (fully stochastic).
+ use_clipped_model_output (`bool`, *optional*, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
- A random number generator.
- variance_noise (`torch.Tensor`):
+ A random number generator for reproducible sampling.
+ variance_noise (`torch.Tensor`, *optional*):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
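A hedged usage sketch of `step()` with the newly documented `eta` and `generator` arguments, assuming `diffusers` is installed; the zero model output is a stand-in for a real UNet prediction:

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)
generator = torch.Generator().manual_seed(0)

sample = torch.randn(1, 3, 64, 64, generator=generator)
for t in scheduler.timesteps:
    model_output = torch.zeros_like(sample)  # stand-in for a UNet noise prediction
    out = scheduler.step(model_output, t, sample, eta=0.0, generator=generator)
    sample = out.prev_sample                 # eta=0.0 keeps the update deterministic (DDIM)
```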
@@ -477,6 +513,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
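The forward-diffusion relation documented for `add_noise` reduces to `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`. A self-contained sketch (not the library code; a linear beta schedule is assumed):

```python
import torch

betas = torch.linspace(0.0001, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def add_noise(x0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    sqrt_alpha_prod = alphas_cumprod[t] ** 0.5
    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[t]) ** 0.5
    while sqrt_alpha_prod.ndim < x0.ndim:  # broadcast per-sample scalars over (C, H, W)
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
    return sqrt_alpha_prod * x0 + sqrt_one_minus_alpha_prod * noise

x0 = torch.randn(2, 3, 8, 8)
noisy = add_noise(x0, torch.randn_like(x0), torch.tensor([10, 500]))
```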
@@ -499,6 +551,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction used for `v_prediction` models from the sample and noise:
+ `velocity = sqrt(alpha_prod_t) * noise - sqrt(1 - alpha_prod_t) * sample`.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
@@ -517,5 +584,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
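Likewise, the `get_velocity` target follows directly from the two lines shown in the hunk above. A minimal sketch for a single timestep, again assuming a linear beta schedule:

```python
import torch

betas = torch.linspace(0.0001, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def get_velocity(sample: torch.Tensor, noise: torch.Tensor, t: int) -> torch.Tensor:
    # velocity = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0
    sqrt_alpha_prod = alphas_cumprod[t] ** 0.5
    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[t]) ** 0.5
    return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample

x0, eps = torch.randn(3, 8, 8), torch.randn(3, 8, 8)
v_target = get_velocity(x0, eps, t=500)
```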
diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
index c19efdc7834d..acb5a5f3e522 100644
--- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
+++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
@@ -18,7 +18,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -49,10 +49,10 @@ class DDIMSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,16 +60,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
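For reference, a cosine `alpha_bar` schedule in the spirit of `betas_for_alpha_bar` can be sketched as follows; the `0.008` offset follows the usual cosine-schedule convention and is an assumption here, not taken from this diff:

```python
import math
import torch

def cosine_betas(num_diffusion_timesteps: int, max_beta: float = 0.999) -> torch.Tensor:
    def alpha_bar(t: float) -> float:
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # Choose beta_i so that prod(1 - beta) tracks alpha_bar, capped at max_beta.
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

print(cosine_betas(1000)[:3])
```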
@@ -408,6 +409,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -430,6 +447,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction used for `v_prediction` models from the sample and noise:
+ `velocity = sqrt(alpha_prod_t) * noise - sqrt(1 - alpha_prod_t) * sample`.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py
index 49dba840d089..a7717940e2a1 100644
--- a/src/diffusers/schedulers/scheduling_ddim_inverse.py
+++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py
@@ -16,7 +16,7 @@
# and https://github.com/hojonathanho/diffusion
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -47,10 +47,10 @@ class DDIMSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -58,16 +58,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -95,13 +96,13 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
-
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
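The zero-terminal-SNR rescaling referenced above (Algorithm 1 of the linked paper) amounts to shifting and scaling `sqrt(alpha_bar)` so the final step reaches zero. A hedged sketch of that idea:

```python
import torch

def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    first, last = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt = alphas_bar_sqrt - last            # shift so the last value is 0
    alphas_bar_sqrt = alphas_bar_sqrt * first / (first - last)  # keep the first value unchanged

    alphas_bar = alphas_bar_sqrt**2
    alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas

betas = torch.linspace(0.0001, 0.02, 1000)
print(rescale_zero_terminal_snr(betas)[-1])  # last beta becomes 1.0, i.e. terminal SNR = 0
```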
diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py
index 7c3f03a8dbe1..d957ade901b3 100644
--- a/src/diffusers/schedulers/scheduling_ddim_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py
@@ -17,7 +17,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -49,10 +49,10 @@ class DDIMParallelSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,16 +60,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -97,13 +98,13 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
-
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -194,17 +195,17 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
- timestep_spacing: str = "leading",
+ timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
@@ -285,6 +286,8 @@ def _batch_get_variance(self, t, prev_t):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
@@ -292,6 +295,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -324,6 +335,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`Union[str, torch.device]`, *optional*):
+ The device to which the timesteps should be moved.
+
+ Raises:
+ ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`.
"""
if num_inference_steps > self.config.num_train_timesteps:
@@ -602,6 +618,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -624,6 +656,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction used for `v_prediction` models from the sample and noise:
+ `velocity = sqrt(alpha_prod_t) * noise - sqrt(1 - alpha_prod_t) * sample`.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 0fab6d910a82..1d0ad49c58cd 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -16,7 +16,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -46,10 +46,10 @@ class DDPMSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -57,16 +57,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -90,17 +91,17 @@ def alpha_bar_fn(t):
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
-def rescale_zero_terminal_snr(betas):
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
-
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -134,39 +135,37 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
- num_train_timesteps (`int`, defaults to 1000):
+ num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
- beta_start (`float`, defaults to 0.0001):
+ beta_start (`float`, defaults to `0.0001`):
The starting `beta` value of inference.
- beta_end (`float`, defaults to 0.02):
+ beta_end (`float`, defaults to `0.02`):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear`, `scaled_linear`, `squaredcos_cap_v2`, or `sigmoid`.
+ beta_schedule (`"linear"`, `"scaled_linear"`, `"squaredcos_cap_v2"`, or `"sigmoid"`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
trained_betas (`np.ndarray`, *optional*):
An array of betas to pass directly to the constructor without using `beta_start` and `beta_end`.
- variance_type (`str`, defaults to `"fixed_small"`):
- Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`,
- `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ variance_type (`"fixed_small"`, `"fixed_small_log"`, `"fixed_large"`, `"fixed_large_log"`, `"learned"`, or `"learned_range"`, defaults to `"fixed_small"`):
+ Clip the variance when adding noise to the denoised sample.
clip_sample (`bool`, defaults to `True`):
Clip the predicted sample for numerical stability.
- clip_sample_range (`float`, defaults to 1.0):
+ clip_sample_range (`float`, defaults to `1.0`):
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
+ prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
- dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ dynamic_thresholding_ratio (`float`, defaults to `0.995`):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
- sample_max_value (`float`, defaults to 1.0):
+ sample_max_value (`float`, defaults to `1.0`):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
- timestep_spacing (`str`, defaults to `"leading"`):
+ timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
- steps_offset (`int`, defaults to 0):
+ steps_offset (`int`, defaults to `0`):
An offset added to the inference steps, as required by some model families.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
@@ -183,16 +182,18 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
- variance_type: str = "fixed_small",
+ variance_type: Literal[
+ "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"
+ ] = "fixed_small",
clip_sample: bool = True,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
- timestep_spacing: str = "leading",
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
):
@@ -322,7 +323,31 @@ def set_timesteps(
self.timesteps = torch.from_numpy(timesteps).to(device)
- def _get_variance(self, t, predicted_variance=None, variance_type=None):
+ def _get_variance(
+ self,
+ t: int,
+ predicted_variance: Optional[torch.Tensor] = None,
+ variance_type: Optional[
+ Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"]
+ ] = None,
+ ) -> torch.Tensor:
+ """
+ Compute the variance for a given timestep according to the specified variance type.
+
+ Args:
+ t (`int`):
+ The current timestep.
+ predicted_variance (`torch.Tensor`, *optional*):
+ The predicted variance from the model. Used only when `variance_type` is `"learned"` or
+ `"learned_range"`.
+ variance_type (`"fixed_small"`, `"fixed_small_log"`, `"fixed_large"`, `"fixed_large_log"`, `"learned"`, or `"learned_range"`, *optional*):
+ The type of variance to compute. If `None`, uses the variance type specified in the scheduler
+ configuration.
+
+ Returns:
+ `torch.Tensor`:
+ The computed variance.
+ """
prev_t = self.previous_timestep(t)
alpha_prod_t = self.alphas_cumprod[t]
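The `"fixed_small"` case documented for `_get_variance` is the standard DDPM posterior variance `beta_tilde_t = (1 - alpha_bar_prev) / (1 - alpha_bar_t) * beta_t`. A sketch under an assumed linear beta schedule (the log, large, and learned variants are not reproduced here):

```python
import torch

betas = torch.linspace(0.0001, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def fixed_small_variance(t: int, prev_t: int) -> torch.Tensor:
    alpha_prod_t = alphas_cumprod[t]
    alpha_prod_t_prev = alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0)
    current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
    variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
    return variance.clamp(min=1e-20)  # keep strictly positive for downstream log/sqrt

print(fixed_small_variance(t=999, prev_t=998))
```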
@@ -364,6 +389,8 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None):
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
@@ -371,6 +398,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -400,7 +435,7 @@ def step(
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
- generator=None,
+ generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
"""
@@ -410,20 +445,19 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
- timestep (`float`):
+ timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
- return_dict (`bool`, *optional*, defaults to `True`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
-
"""
t = timestep
@@ -504,6 +538,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -525,6 +575,21 @@ def add_noise(
return noisy_samples
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction used for `v_prediction` models from the sample and noise:
+ `velocity = sqrt(alpha_prod_t) * noise - sqrt(1 - alpha_prod_t) * sample`.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
@@ -543,10 +608,21 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
- def previous_timestep(self, timestep):
+ def previous_timestep(self, timestep: int) -> int:
+ """
+ Compute the previous timestep in the diffusion chain.
+
+ Args:
+ timestep (`int`):
+ The current timestep.
+
+ Returns:
+ `int`:
+ The previous timestep.
+ """
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
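The lookup branch shown in `previous_timestep` can be illustrated with an explicit descending schedule; the `-1` sentinel at the end of the chain matches the condition in the hunk above:

```python
import torch

timesteps = torch.arange(980, -1, -20)  # e.g. leading spacing: 980, 960, ..., 0

def previous_timestep(timestep: int) -> int:
    index = (timesteps == timestep).nonzero(as_tuple=True)[0][0]
    return -1 if index == timesteps.shape[0] - 1 else int(timesteps[index + 1])

print(previous_timestep(980), previous_timestep(0))  # 960 -1
```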
diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
index ec741f9ecb7d..78011d0e46a1 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -16,7 +16,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -48,10 +48,10 @@ class DDPMParallelSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -59,16 +59,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -96,13 +97,13 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
-
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -191,16 +192,18 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
- variance_type: str = "fixed_small",
+ variance_type: Literal[
+ "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"
+ ] = "fixed_small",
clip_sample: bool = True,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
- timestep_spacing: str = "leading",
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
):
@@ -333,7 +336,31 @@ def set_timesteps(
self.timesteps = torch.from_numpy(timesteps).to(device)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._get_variance
- def _get_variance(self, t, predicted_variance=None, variance_type=None):
+ def _get_variance(
+ self,
+ t: int,
+ predicted_variance: Optional[torch.Tensor] = None,
+ variance_type: Optional[
+ Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"]
+ ] = None,
+ ) -> torch.Tensor:
+ """
+ Compute the variance for a given timestep according to the specified variance type.
+
+ Args:
+ t (`int`):
+ The current timestep.
+ predicted_variance (`torch.Tensor`, *optional*):
+ The predicted variance from the model. Used only when `variance_type` is `"learned"` or
+ `"learned_range"`.
+ variance_type (`"fixed_small"`, `"fixed_small_log"`, `"fixed_large"`, `"fixed_large_log"`, `"learned"`, or `"learned_range"`, *optional*):
+ The type of variance to compute. If `None`, uses the variance type specified in the scheduler
+ configuration.
+
+ Returns:
+ `torch.Tensor`:
+ The computed variance.
+ """
prev_t = self.previous_timestep(t)
alpha_prod_t = self.alphas_cumprod[t]
@@ -376,6 +403,8 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
@@ -383,6 +412,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -593,6 +630,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -615,6 +668,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction used for `v_prediction` models from the sample and noise:
+ `velocity = sqrt(alpha_prod_t) * noise - sqrt(1 - alpha_prod_t) * sample`.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
@@ -638,6 +706,17 @@ def __len__(self):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
+ """
+ Compute the previous timestep in the diffusion chain.
+
+ Args:
+ timestep (`int`):
+ The current timestep.
+
+ Returns:
+ `int`:
+ The previous timestep.
+ """
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index 7d8685ba10c3..bf8e1d98d6c0 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -16,7 +16,7 @@
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -32,10 +32,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -43,16 +43,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -229,7 +230,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -320,6 +321,8 @@ def set_timesteps(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
@@ -327,6 +330,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -353,6 +364,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
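The new `_sigma_to_t` docstring describes interpolation in log-sigma space. A simplified sketch using `np.interp` (the scheduler's own version does the index arithmetic by hand; the sigma range below is an assumption for illustration):

```python
import numpy as np

log_sigmas = np.log(np.linspace(0.02, 80.0, 1000))  # assumed increasing sigma schedule

def sigma_to_t(sigma: np.ndarray) -> np.ndarray:
    # Map each sigma to a fractional timestep by interpolating in log space.
    log_sigma = np.log(np.maximum(sigma, 1e-10))
    timesteps = np.arange(len(log_sigmas), dtype=np.float64)
    return np.interp(log_sigma, log_sigmas, timesteps)

print(sigma_to_t(np.array([0.05, 1.0, 40.0])))
```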
@@ -388,7 +412,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
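The Karras et al. (2022) schedule named in the new `_convert_to_karras` docstring interpolates `sigma^(1/rho)` between `sigma_max` and `sigma_min`. A short sketch (`rho = 7.0` is the paper's default and an assumption here):

```python
import torch

def karras_sigmas(sigma_min: float, sigma_max: float, num_inference_steps: int, rho: float = 7.0) -> torch.Tensor:
    ramp = torch.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    # Interpolate in sigma^(1/rho) space, then raise back to the rho-th power.
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(karras_sigmas(0.02, 80.0, 10))
```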
@@ -414,7 +451,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
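The exponential schedule documented above is log-linear spacing between `sigma_max` and `sigma_min`; a minimal sketch:

```python
import math
import torch

def exponential_sigmas(sigma_min: float, sigma_max: float, num_inference_steps: int) -> torch.Tensor:
    # Evenly spaced in log(sigma), decreasing from sigma_max to sigma_min.
    return torch.exp(torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))

print(exponential_sigmas(0.02, 80.0, 10))
```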
@@ -438,7 +487,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
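The beta schedule from the cited paper places sigmas at quantiles of a Beta(alpha, beta) distribution. A hedged sketch assuming `scipy` is available; this illustrates the quantile construction rather than copying the library code:

```python
import numpy as np
import scipy.stats

def beta_sigmas(sigma_min: float, sigma_max: float, num_inference_steps: int,
                alpha: float = 0.6, beta: float = 0.6) -> np.ndarray:
    percentiles = np.linspace(1, 0, num_inference_steps)  # descending so sigmas start high
    ppf = np.array([scipy.stats.beta.ppf(p, alpha, beta) for p in percentiles])
    return sigma_min + ppf * (sigma_max - sigma_min)

print(beta_sigmas(0.02, 80.0, 10))
```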
diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py
index f7b63720e107..c5d79b5fe54a 100644
--- a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py
+++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py
@@ -18,7 +18,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -50,10 +50,10 @@ class DDIMSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -61,16 +61,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -445,6 +446,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -467,6 +484,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction used for `v_prediction` models from the sample and noise:
+ `velocity = sqrt(alpha_prod_t) * noise - sqrt(1 - alpha_prod_t) * sample`.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 8b523cd13f1f..dee97f39ff68 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -32,10 +32,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -43,16 +43,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -80,13 +81,13 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
-
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -323,7 +324,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -460,6 +461,8 @@ def set_timesteps(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
@@ -467,6 +470,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -493,6 +504,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -527,7 +551,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -566,7 +603,19 @@ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
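
The exponential variant amounts to geometric interpolation between the extreme sigmas, i.e. evenly spaced in log(sigma); a minimal sketch:

import math
import torch

def exponential_sigmas_sketch(sigma_min: float, sigma_max: float, num_steps: int) -> torch.Tensor:
    # Evenly spaced in log(sigma), decreasing from sigma_max to sigma_min.
    return torch.linspace(math.log(sigma_max), math.log(sigma_min), num_steps).exp()
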
@@ -590,7 +639,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
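
The beta variant places sigmas at quantiles of a Beta(alpha, beta) distribution between `sigma_min` and `sigma_max`; a sketch using scipy (which this option also requires in the schedulers):

import numpy as np
import scipy.stats

def beta_sigmas_sketch(
    sigma_min: float, sigma_max: float, num_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> np.ndarray:
    # Map a reversed uniform grid through the inverse CDF (ppf) of Beta(alpha, beta),
    # so the resulting sigmas decrease from sigma_max to sigma_min.
    quantiles = scipy.stats.beta.ppf(1 - np.linspace(0, 1, num_steps), alpha, beta)
    return sigma_min + quantiles * (sigma_max - sigma_min)
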
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
index f1a1ac3d8216..0f734aeb54c9 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -32,10 +32,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -43,16 +43,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
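
For the cosine transform, each beta comes from the ratio of consecutive `alpha_bar` values, capped at `max_beta`; roughly (a standalone sketch of the idea):

import math
import torch

def betas_for_alpha_bar_sketch(num_diffusion_timesteps: int, max_beta: float = 0.999) -> torch.Tensor:
    def alpha_bar(t: float) -> float:
        # Cosine schedule from Nichol & Dhariwal (2021).
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1), capped at max_beta.
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
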
@@ -332,6 +333,8 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
@@ -339,6 +342,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -365,6 +376,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -400,7 +424,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -426,7 +463,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -450,7 +499,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
index eeb06773d977..ef89feb1cad6 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -115,10 +115,10 @@ def __call__(self, sigma, sigma_next):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -126,16 +126,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -250,7 +251,23 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
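
In isolation, the lookup described in that docstring amounts to the following sketch (standalone, outside the scheduler class):

import torch

def index_for_timestep_sketch(timestep: float, schedule_timesteps: torch.Tensor) -> int:
    indices = (schedule_timesteps == timestep).nonzero()
    # When a timestep occurs more than once, take the second match: this keeps an
    # image-to-image run that starts mid-schedule from skipping a sigma on its first step.
    pos = 1 if len(indices) > 1 else 0
    return indices[pos].item()
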
@@ -265,7 +282,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -301,7 +325,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -429,6 +453,19 @@ def t_fn(_sigma):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -451,9 +488,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
t = t.reshape(sigma.shape)
return t
- # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+ # Copied from diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
@@ -467,7 +515,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -491,7 +551,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -645,6 +722,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
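
For these sigma-parameterized schedulers the forward process is additive in sigma; conceptually, and leaving out the device/dtype and mps bookkeeping shown above (variance-preserving schedulers instead mix `sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise`):

import torch

def add_noise_sketch(original_samples: torch.Tensor, noise: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # `sigma` holds one noise level per sample, looked up from the schedule at the requested timesteps.
    sigma = sigma.flatten().to(dtype=original_samples.dtype, device=original_samples.device)
    while sigma.ndim < original_samples.ndim:
        sigma = sigma.unsqueeze(-1)
    # Variance-exploding forward process: x_t = x_0 + sigma_t * noise.
    return original_samples + noise * sigma
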
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 1ae824973034..0b271d7eacb4 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -34,10 +34,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -45,16 +45,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -294,7 +295,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -410,6 +411,8 @@ def set_timesteps(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
@@ -417,6 +420,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -443,6 +454,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -478,7 +502,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -504,7 +541,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -528,7 +577,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
index e9ba695e1f39..eeec588e27a3 100644
--- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
@@ -169,7 +169,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -299,6 +299,8 @@ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> t
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
@@ -306,6 +308,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -332,6 +342,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -672,6 +695,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py
index dbeff3de5652..0bf17356a7fa 100644
--- a/src/diffusers/schedulers/scheduling_edm_euler.py
+++ b/src/diffusers/schedulers/scheduling_edm_euler.py
@@ -155,7 +155,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -284,7 +284,23 @@ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> t
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -299,7 +315,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -413,6 +436,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
index 9cdaa2c5e101..8f39507301ce 100644
--- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -49,10 +49,10 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,16 +60,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -97,13 +98,13 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
-
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -245,7 +246,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -319,7 +320,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -334,7 +351,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -451,6 +475,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index f58d918dbfbe..5ea926c4ca38 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -52,10 +52,10 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -63,16 +63,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -96,17 +97,17 @@ def alpha_bar_fn(t):
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
-def rescale_zero_terminal_snr(betas):
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
-
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -146,17 +147,17 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
+ beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear` or `scaled_linear`.
+ `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
- Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+ prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`, *optional*):
+ Prediction type of the scheduler function; can be `"epsilon"` (predicts the noise of the diffusion
+ process), `"sample"` (directly predicts the noisy sample`) or `"v_prediction"` (see section 2.4 of [Imagen
+ process), `"sample"` (directly predicts the noisy sample) or `"v_prediction"` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
- interpolation_type(`str`, defaults to `"linear"`, *optional*):
- The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be on of
+ interpolation_type (`Literal["linear", "log_linear"]`, defaults to `"linear"`, *optional*):
+ The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
`"linear"` or `"log_linear"`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
@@ -166,18 +167,26 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
- timestep_spacing (`str`, defaults to `"linspace"`):
+ sigma_min (`float`, *optional*):
+ The minimum sigma value for the noise schedule. If not provided, defaults to the last sigma in the
+ schedule.
+ sigma_max (`float`, *optional*):
+ The maximum sigma value for the noise schedule. If not provided, defaults to the first sigma in the
+ schedule.
+ timestep_spacing (`Literal["linspace", "leading", "trailing"]`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ timestep_type (`Literal["discrete", "continuous"]`, defaults to `"discrete"`):
+ The type of timesteps to use. Can be `"discrete"` or `"continuous"`.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
- final_sigmas_type (`str`, defaults to `"zero"`):
+ final_sigmas_type (`Literal["zero", "sigma_min"]`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
- sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
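
As a usage note for the options documented above, a short example (the specific values are chosen for illustration, not as recommendations):

from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(
    beta_schedule="scaled_linear",
    prediction_type="v_prediction",
    timestep_spacing="trailing",
    rescale_betas_zero_snr=True,  # commonly paired with "trailing" spacing per the zero-SNR paper
    final_sigmas_type="zero",
)
scheduler.set_timesteps(num_inference_steps=30)
print(len(scheduler.timesteps), scheduler.init_noise_sigma)
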
@@ -189,20 +198,20 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
- prediction_type: str = "epsilon",
- interpolation_type: str = "linear",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
+ interpolation_type: Literal["linear", "log_linear"] = "linear",
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
- timestep_spacing: str = "linspace",
- timestep_type: str = "discrete", # can be "discrete" or "continuous"
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
+ timestep_type: Literal["discrete", "continuous"] = "discrete",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
- final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
+ final_sigmas_type: Literal["zero", "sigma_min"] = "zero",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -259,8 +268,15 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
- def init_noise_sigma(self):
- # standard deviation of the initial noise distribution
+ def init_noise_sigma(self) -> Union[float, torch.Tensor]:
+ """
+ The standard deviation of the initial noise distribution.
+
+ Returns:
+ `float` or `torch.Tensor`:
+ The standard deviation of the initial noise distribution, computed based on the maximum sigma value and
+ the timestep spacing configuration.
+ """
max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
if self.config.timestep_spacing in ["linspace", "trailing"]:
return max_sigma
@@ -268,26 +284,34 @@ def init_noise_sigma(self):
return (max_sigma**2 + 1) ** 0.5
@property
- def step_index(self):
+ def step_index(self) -> Optional[int]:
"""
- The index counter for current timestep. It will increase 1 after each scheduler step.
+ The index counter for current timestep. It will increase by 1 after each scheduler step.
+
+ Returns:
+ `int` or `None`:
+ The current step index, or `None` if not initialized.
"""
return self._step_index
@property
- def begin_index(self):
+ def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+ Returns:
+ `int` or `None`:
+ The begin index for the scheduler, or `None` if not set.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
- def set_begin_index(self, begin_index: int = 0):
+ def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -299,13 +323,13 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
Args:
sample (`torch.Tensor`):
- The input sample.
- timestep (`int`, *optional*):
+ The input sample to be scaled.
+ timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
`torch.Tensor`:
- A scaled input sample.
+ A scaled input sample, divided by `(sigma**2 + 1) ** 0.5`.
"""
if self.step_index is None:
self._init_step_index(timestep)
@@ -318,17 +342,18 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
def set_timesteps(
self,
- num_inference_steps: int = None,
- device: Union[str, torch.device] = None,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
- ):
+ ) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model.
+ num_inference_steps (`int`, *optional*):
+ The number of diffusion steps used when generating samples with a pre-trained model. If `None`,
+ `timesteps` or `sigmas` must be provided.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -336,10 +361,9 @@ def set_timesteps(
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
must be `None`, and `timestep_spacing` attribute will be ignored.
sigmas (`List[float]`, *optional*):
- Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas
- will be generated based on the relevant scheduler attributes. If `sigmas` is passed,
- `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the
- custom sigmas schedule.
+ Custom sigmas used to support arbitrary timesteps schedule. If `None`, timesteps and sigmas will be
+ generated based on the relevant scheduler attributes. If `sigmas` is passed, `num_inference_steps` and
+ `timesteps` must be `None`, and the timesteps will be generated based on the custom sigmas schedule.
"""
if timesteps is not None and sigmas is not None:
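
The three call forms are mutually exclusive; for example (the custom timestep and sigma values below are purely illustrative):

from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=30)                    # spacing follows `timestep_spacing`
# scheduler.set_timesteps(timesteps=[999, 749, 499, 249, 0])       # or explicit timesteps
# scheduler.set_timesteps(sigmas=[14.6, 6.3, 2.7, 1.1, 0.4, 0.0])  # or an explicit sigma schedule
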
@@ -449,7 +473,20 @@ def set_timesteps(
self._begin_index = None
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
- def _sigma_to_t(self, sigma, log_sigmas):
+ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -473,8 +510,21 @@ def _sigma_to_t(self, sigma, log_sigmas):
return t
# Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
- def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -500,7 +550,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -523,7 +585,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -551,7 +630,23 @@ def _convert_to_beta(
)
return sigmas
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -565,7 +660,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -591,26 +693,33 @@ def step(
Args:
model_output (`torch.Tensor`):
- The direct output from learned diffusion model.
- timestep (`float`):
+ The direct output from the learned diffusion model.
+ timestep (`float` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- s_churn (`float`):
- s_tmin (`float`):
- s_tmax (`float`):
- s_noise (`float`, defaults to 1.0):
+ s_churn (`float`, *optional*, defaults to `0.0`):
+ Stochasticity parameter that controls the amount of noise added during sampling. Higher values increase
+ randomness.
+ s_tmin (`float`, *optional*, defaults to `0.0`):
+ Minimum timestep threshold for applying stochasticity. Only timesteps above this value will have noise
+ added.
+ s_tmax (`float`, *optional*, defaults to `inf`):
+ Maximum timestep threshold for applying stochasticity. Only timesteps below this value will have noise
+ added.
+ s_noise (`float`, *optional*, defaults to `1.0`):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
- A random number generator.
- return_dict (`bool`):
+ A random number generator for reproducible sampling.
+ return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
tuple.
Returns:
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
- If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
- returned, otherwise a tuple is returned where the first element is the sample tensor.
+ If `return_dict` is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
+ returned, otherwise a tuple is returned where the first element is the sample tensor and the second
+ element is the predicted original sample.
"""
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
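
How the four stochasticity knobs interact can be sketched as follows, paraphrasing the Karras-style churn step (a standalone sketch, not a drop-in replacement for the method body):

import torch

def churn_sigma_sketch(
    sample: torch.Tensor, sigma: float, s_churn: float, s_tmin: float, s_tmax: float,
    s_noise: float, num_sigmas: int,
) -> tuple:
    # Churn is only applied when sigma falls inside [s_tmin, s_tmax].
    gamma = min(s_churn / (num_sigmas - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
    sigma_hat = sigma * (gamma + 1)
    if gamma > 0:
        # Temporarily raise the noise level from sigma to sigma_hat by adding fresh noise.
        eps = torch.randn_like(sample) * s_noise
        sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
    return sample, sigma_hat
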
@@ -689,6 +798,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -717,6 +841,24 @@ def add_noise(
return noisy_samples
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction for the given sample and noise at the specified timesteps.
+
+ This method computes the velocity target used by v-prediction models, defined as a linear combination of the
+ sample and the noise.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample for which to compute the velocity.
+ noise (`torch.Tensor`):
+ The noise tensor corresponding to the sample.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to compute the velocity.
+
+ Returns:
+ `torch.Tensor`:
+ The velocity prediction computed as `sqrt(alpha_prod) * noise - sqrt(1 - alpha_prod) * sample`.
+ """
if (
isinstance(timesteps, int)
or isinstance(timesteps, torch.IntTensor)
@@ -753,5 +895,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
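
The velocity target documented in `get_velocity` can be written out as a standalone formula sketch (helper name and arguments are illustrative):

import torch

def velocity_sketch(
    sample: torch.Tensor, noise: torch.Tensor, alphas_cumprod: torch.Tensor, timesteps: torch.Tensor
) -> torch.Tensor:
    # `timesteps` is an integer tensor indexing into `alphas_cumprod`.
    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
    while sqrt_alpha_prod.ndim < sample.ndim:
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
    # v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0
    return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
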
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index 1a4f12ddfa53..9fd61d9e18d1 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -160,7 +160,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -473,7 +473,20 @@ def step(
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
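A standalone sketch of the Karras construction described above, assuming the conventional `rho = 7` (not part of the patch):

    import numpy as np

    # interpolate sigma^(1/rho) linearly between the endpoints, then raise back to the rho power
    def karras_sigmas(sigma_min: float, sigma_max: float, n: int, rho: float = 7.0) -> np.ndarray:
        ramp = np.linspace(0, 1, n)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

    print(karras_sigmas(0.02, 80.0, 5))  # monotonically decreasing from sigma_max to sigma_min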
@@ -499,7 +512,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
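A standalone sketch of the exponential schedule described above, i.e. sigmas evenly spaced in log space (not part of the patch):

    import numpy as np

    def exponential_sigmas(sigma_min: float, sigma_max: float, n: int) -> np.ndarray:
        return np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), n))

    print(exponential_sigmas(0.02, 80.0, 5))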
@@ -523,7 +548,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
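A standalone sketch of the beta schedule described above: evenly spaced quantiles are pushed through the `Beta(alpha, beta)` inverse CDF and rescaled to `[sigma_min, sigma_max]` (not part of the patch):

    import numpy as np
    import scipy.stats

    def beta_sigmas(sigma_min, sigma_max, n, alpha=0.6, beta=0.6):
        quantiles = 1 - np.linspace(0, 1, n)
        ppf = scipy.stats.beta.ppf(quantiles, alpha, beta)
        return sigma_min + ppf * (sigma_max - sigma_min)

    print(beta_sigmas(0.02, 80.0, 5))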
diff --git a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py
index 38e5f1ba77a8..6febee444c5a 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py
@@ -102,7 +102,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
diff --git a/src/diffusers/schedulers/scheduling_flow_match_lcm.py b/src/diffusers/schedulers/scheduling_flow_match_lcm.py
index 933bb1cf8e3d..25186d1fe969 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_lcm.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_lcm.py
@@ -168,7 +168,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -473,7 +473,20 @@ def step(
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -499,7 +512,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -523,7 +548,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py
index bd1239cfaec7..930b0344646d 100644
--- a/src/diffusers/schedulers/scheduling_heun_discrete.py
+++ b/src/diffusers/schedulers/scheduling_heun_discrete.py
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -49,10 +49,10 @@ class HeunDiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,16 +60,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
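A standalone sketch of the cosine variant of this construction, `beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i)` capped at `max_beta` (not part of the patch):

    import math
    import torch

    def cosine_betas(num_diffusion_timesteps: int, max_beta: float = 0.999) -> torch.Tensor:
        def alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        betas = []
        for i in range(num_diffusion_timesteps):
            t1 = i / num_diffusion_timesteps
            t2 = (i + 1) / num_diffusion_timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        return torch.tensor(betas, dtype=torch.float32)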
@@ -187,7 +188,23 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
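A standalone sketch of the lookup rule documented above: when a timestep occurs more than once (second-order schedules duplicate timesteps), the second occurrence is returned so the first step does not skip a sigma (not part of the patch; the timestep values are made up):

    import torch

    def index_for_timestep(timestep, schedule_timesteps):
        indices = (schedule_timesteps == timestep).nonzero()
        pos = 1 if len(indices) > 1 else 0
        return indices[pos].item()

    timesteps = torch.tensor([999, 999, 801, 801, 603, 603])
    print(index_for_timestep(999, timesteps))  # -> 1
    print(index_for_timestep(603, timesteps))  # -> 5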
@@ -229,7 +246,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -354,6 +371,19 @@ def set_timesteps(
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
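A standalone sketch of this interpolation: the sigma is located in the log-sigma schedule and a fractional index is returned (not part of the patch; the schedule values are made up and assumed ascending, as in the training-order sigmas):

    import numpy as np

    def sigma_to_t(sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
        log_sigma = np.log(np.maximum(sigma, 1e-10))
        dists = log_sigma - log_sigmas[:, np.newaxis]
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = log_sigmas[low_idx], log_sigmas[high_idx]
        w = np.clip((low - log_sigma) / (low - high), 0, 1)
        return ((1 - w) * low_idx + w * high_idx).reshape(sigma.shape)

    log_sigmas = np.log(np.array([0.1, 1.0, 5.0, 20.0, 80.0]))
    print(sigma_to_t(np.array([10.0]), log_sigmas))  # -> [2.5], halfway in log space between indices 2 and 3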
@@ -378,7 +408,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -404,7 +447,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -428,7 +483,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -461,7 +533,14 @@ def state_in_first_order(self):
return self.dt is None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -579,6 +658,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
diff --git a/src/diffusers/schedulers/scheduling_ipndm.py b/src/diffusers/schedulers/scheduling_ipndm.py
index 23bc21f10ca4..da188fe8297c 100644
--- a/src/diffusers/schedulers/scheduling_ipndm.py
+++ b/src/diffusers/schedulers/scheduling_ipndm.py
@@ -78,7 +78,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -112,7 +112,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self._begin_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -127,7 +143,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
index 6588464073a1..595b93c39d4c 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -50,10 +50,10 @@ class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -61,16 +61,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -206,7 +207,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -342,6 +343,19 @@ def set_timesteps(
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -366,7 +380,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -392,7 +419,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -416,7 +455,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -449,7 +505,23 @@ def state_in_first_order(self):
return self.sample is None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -464,7 +536,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -586,6 +665,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
index 9b4cd4e204d6..7db12227229e 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -49,10 +49,10 @@ class KDPM2DiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,16 +60,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -206,7 +207,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -330,7 +331,23 @@ def state_in_first_order(self):
return self.sample is None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -345,7 +362,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -355,6 +379,19 @@ def _init_step_index(self, timestep):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -379,7 +416,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -405,7 +455,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -429,7 +491,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -558,6 +637,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py
index cd7a29fe675f..a7b0644de4f5 100644
--- a/src/diffusers/schedulers/scheduling_lcm.py
+++ b/src/diffusers/schedulers/scheduling_lcm.py
@@ -17,7 +17,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -51,10 +51,10 @@ class LCMSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -62,16 +62,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -99,13 +100,13 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
-
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
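A standalone sketch of the rescaling: `sqrt(alpha_bar)` is shifted so its last value is exactly zero and rescaled so its first value is unchanged, then converted back to betas (not part of the patch):

    import torch

    def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
        alphas_bar_sqrt = (1.0 - betas).cumprod(dim=0).sqrt()
        first, last = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
        alphas_bar_sqrt = (alphas_bar_sqrt - last) * first / (first - last)  # last -> 0, first unchanged
        alphas_bar = alphas_bar_sqrt**2
        alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
        return 1.0 - alphas

    betas = torch.linspace(0.0001, 0.02, 1000)
    print((1 - rescale_zero_terminal_snr(betas)).cumprod(dim=0)[-1])  # terminal alpha_bar is 0 -> zero SNR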
@@ -251,7 +252,23 @@ def __init__(
self._begin_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -266,7 +283,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -291,7 +315,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -315,6 +339,8 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
@@ -322,6 +348,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
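A standalone sketch of dynamic thresholding (not part of the patch; the percentile and cap values here are illustrative, the real ones come from the scheduler config):

    import torch

    # clamp each sample to [-s, s] where s is a high percentile of |x0|, then rescale into [-1, 1]
    def threshold_sample(x0: torch.Tensor, ratio: float = 0.995, max_value: float = 1.5) -> torch.Tensor:
        batch = x0.shape[0]
        s = torch.quantile(x0.reshape(batch, -1).abs(), ratio, dim=1)
        s = s.clamp(min=1.0, max=max_value)            # only active when the percentile exceeds 1
        s = s.reshape(batch, *([1] * (x0.ndim - 1)))
        return x0.clamp(-s, s) / s

    x0 = 3.0 * torch.randn(2, 4, 8, 8)
    print(threshold_sample(x0).abs().max())  # <= 1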
@@ -597,6 +631,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -619,6 +669,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction `sqrt(alpha_prod) * noise - sqrt(1 - alpha_prod) * sample`, used as the
+ target in v-prediction training.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
@@ -642,6 +707,17 @@ def __len__(self):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
+ """
+ Compute the previous timestep in the diffusion chain.
+
+ Args:
+ timestep (`int`):
+ The current timestep.
+
+ Returns:
+ `int`:
+ The previous timestep.
+ """
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index c2450204aa8f..d0766eed1b66 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -14,7 +14,7 @@
import math
import warnings
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import scipy.stats
@@ -47,10 +47,10 @@ class LMSDiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -58,16 +58,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -98,15 +99,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
- num_train_timesteps (`int`, defaults to 1000):
+ num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
- beta_start (`float`, defaults to 0.0001):
+ beta_start (`float`, defaults to `0.0001`):
The starting `beta` value of inference.
- beta_end (`float`, defaults to 0.02):
+ beta_end (`float`, defaults to `0.02`):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear` or `scaled_linear`.
+ beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -117,14 +117,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
+ prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
- timestep_spacing (`str`, defaults to `"linspace"`):
+ timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
- steps_offset (`int`, defaults to 0):
+ steps_offset (`int`, defaults to `0`):
An offset added to the inference steps, as required by some model families.
"""
@@ -137,13 +137,13 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
- prediction_type: str = "epsilon",
- timestep_spacing: str = "linspace",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0,
):
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
@@ -182,7 +182,15 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
- def init_noise_sigma(self):
+ def init_noise_sigma(self) -> Union[float, torch.Tensor]:
+ """
+ The standard deviation of the initial noise distribution.
+
+ Returns:
+ `float` or `torch.Tensor`:
+ The standard deviation of the initial noise distribution, computed based on the maximum sigma value and
+ the timestep spacing configuration.
+ """
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
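A minimal usage sketch (not part of the patch): pipelines scale their initial latents by this value before the denoising loop; the latent shape here is arbitrary.

    import torch
    from diffusers import LMSDiscreteScheduler

    scheduler = LMSDiscreteScheduler()
    scheduler.set_timesteps(25)
    latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma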
@@ -190,26 +198,34 @@ def init_noise_sigma(self):
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
- def step_index(self):
+ def step_index(self) -> Optional[int]:
"""
- The index counter for current timestep. It will increase 1 after each scheduler step.
+ The index counter for the current timestep. It will increase by 1 after each scheduler step.
+
+ Returns:
+ `int` or `None`:
+ The current step index, or `None` if not initialized.
"""
return self._step_index
@property
- def begin_index(self):
+ def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+ Returns:
+ `int` or `None`:
+ The begin index for the scheduler, or `None` if not set.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
- def set_begin_index(self, begin_index: int = 0):
+ def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -238,14 +254,21 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
self.is_scale_input_called = True
return sample
- def get_lms_coefficient(self, order, t, current_order):
+ def get_lms_coefficient(self, order: int, t: int, current_order: int) -> float:
"""
Compute the linear multistep coefficient.
Args:
- order ():
- t ():
- current_order ():
+ order (`int`):
+ The order of the linear multistep method.
+ t (`int`):
+ The current timestep index.
+ current_order (`int`):
+ The current order for which to compute the coefficient.
+
+ Returns:
+ `float`:
+ The computed linear multistep coefficient.
"""
def lms_derivative(tau):
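A standalone sketch of the coefficient: the Lagrange basis polynomial for `current_order` is integrated between two consecutive sigmas (not part of the patch; the sigma values are made up):

    import numpy as np
    from scipy import integrate

    def lms_coefficient(sigmas, order, t, current_order):
        def lms_derivative(tau):
            prod = 1.0
            for k in range(order):
                if current_order == k:
                    continue
                prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
            return prod

        return integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], epsrel=1e-4)[0]

    sigmas = np.array([10.0, 6.0, 3.0, 1.5, 0.5, 0.0])
    print(lms_coefficient(sigmas, order=2, t=2, current_order=0))
    print(lms_coefficient(sigmas, order=2, t=2, current_order=1))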
@@ -260,7 +283,7 @@ def lms_derivative(tau):
return integrated_coeff
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -319,7 +342,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self.derivatives = []
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -334,7 +373,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -343,7 +389,20 @@ def _init_step_index(self, timestep):
self._step_index = self._begin_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
- def _sigma_to_t(self, sigma, log_sigmas):
+ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -366,9 +425,19 @@ def _sigma_to_t(self, sigma, log_sigmas):
t = t.reshape(sigma.shape)
return t
- # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
@@ -382,7 +451,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -406,7 +487,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -521,6 +619,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -548,5 +661,5 @@ def add_noise(
noisy_samples = original_samples + noise * sigma
return noisy_samples
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index c07621179e2b..651532b06ddb 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -26,10 +26,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -37,16 +37,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -78,15 +79,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
- num_train_timesteps (`int`, defaults to 1000):
+ num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
- beta_start (`float`, defaults to 0.0001):
+ beta_start (`float`, defaults to `0.0001`):
The starting `beta` value of inference.
- beta_end (`float`, defaults to 0.02):
+ beta_end (`float`, defaults to `0.02`):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
skip_prk_steps (`bool`, defaults to `False`):
@@ -96,14 +96,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the alpha value at step 0.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
+ prediction_type (`"epsilon"` or `"v_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
- or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
- paper).
- timestep_spacing (`str`, defaults to `"leading"`):
+ or `v_prediction` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
+ timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
- steps_offset (`int`, defaults to 0):
+ steps_offset (`int`, defaults to `0`):
An offset added to the inference steps, as required by some model families.
"""
@@ -116,12 +115,12 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
- prediction_type: str = "epsilon",
- timestep_spacing: str = "leading",
+ prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
steps_offset: int = 0,
):
if trained_betas is not None:
@@ -163,7 +162,7 @@ def __init__(
self.plms_timesteps = None
self.timesteps = None
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -242,7 +241,7 @@ def step(
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- return_dict (`bool`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
@@ -275,14 +274,13 @@ def step_prk(
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- return_dict (`bool`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
-
"""
if self.num_inference_steps is None:
raise ValueError(
@@ -334,14 +332,13 @@ def step_plms(
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- return_dict (`bool`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
-
"""
if self.num_inference_steps is None:
raise ValueError(
@@ -402,19 +399,27 @@ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tens
"""
return sample
- def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
- # See formula (9) of PNDM paper https://huggingface.co/papers/2202.09778
- # this function computes x_(t−δ) using the formula of (9)
- # Note that x_t needs to be added to both sides of the equation
-
- # Notation ( ->
- # alpha_prod_t -> α_t
- # alpha_prod_t_prev -> α_(t−δ)
- # beta_prod_t -> (1 - α_t)
- # beta_prod_t_prev -> (1 - α_(t−δ))
- # sample -> x_t
- # model_output -> e_θ(x_t, t)
- # prev_sample -> x_(t−δ)
+ def _get_prev_sample(
+ self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Compute the previous sample x_(t-δ) from the current sample x_t using formula (9) from the [PNDM
+ paper](https://huggingface.co/papers/2202.09778).
+
+ Args:
+ sample (`torch.Tensor`):
+ The current sample x_t.
+ timestep (`int`):
+ The current timestep t.
+ prev_timestep (`int`):
+ The previous timestep (t-δ).
+ model_output (`torch.Tensor`):
+ The model output e_θ(x_t, t).
+
+ Returns:
+ `torch.Tensor`:
+ The previous sample x_(t-δ).
+ """
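+        # Notation: alpha_prod_t -> α_t, alpha_prod_t_prev -> α_(t−δ), beta_prod_t -> (1 - α_t),
+        # beta_prod_t_prev -> (1 - α_(t−δ)), sample -> x_t, model_output -> e_θ(x_t, t), prev_sample -> x_(t−δ)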
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
@@ -452,6 +457,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -472,5 +493,5 @@ def add_noise(
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py
index 6530c5af9e5b..a2eaf8eb3abd 100644
--- a/src/diffusers/schedulers/scheduling_repaint.py
+++ b/src/diffusers/schedulers/scheduling_repaint.py
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -45,10 +45,10 @@ class RePaintSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -56,16 +56,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py
index 2979ce193a36..d9054c39c9de 100644
--- a/src/diffusers/schedulers/scheduling_sasolver.py
+++ b/src/diffusers/schedulers/scheduling_sasolver.py
@@ -16,7 +16,7 @@
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
import math
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -33,10 +33,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -44,16 +44,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -253,7 +254,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -342,6 +343,8 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
@@ -349,6 +352,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -375,6 +386,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -410,7 +434,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -436,7 +473,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -460,7 +509,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -1193,6 +1259,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py
index 63b4a109ff9b..7b01d886299c 100644
--- a/src/diffusers/schedulers/scheduling_scm.py
+++ b/src/diffusers/schedulers/scheduling_scm.py
@@ -109,7 +109,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -173,7 +173,14 @@ def set_timesteps(
self._begin_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -182,7 +189,23 @@ def _init_step_index(self, timestep):
self._step_index = self._begin_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py
index 3fd5c341eca9..37b41c87f8a2 100644
--- a/src/diffusers/schedulers/scheduling_tcd.py
+++ b/src/diffusers/schedulers/scheduling_tcd.py
@@ -17,7 +17,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -50,10 +50,10 @@ class TCDSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -61,16 +61,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -98,13 +99,13 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
-
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -252,7 +253,23 @@ def __init__(
self._begin_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -267,7 +284,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -292,7 +316,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -316,6 +340,24 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler._get_variance
def _get_variance(self, timestep, prev_timestep):
+ """
+        Compute the variance of the noise added at a given diffusion step.
+
+        For a given `timestep` and its previous step, this method calculates the variance as defined in the DDIM/DDPM
+        literature:
+            var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+        where alpha_prod_t is the cumulative product of alphas up to timestep t and beta_prod_t = 1 - alpha_prod_t.
+
+ Args:
+ timestep (`int`):
+ The current timestep in the diffusion process.
+ prev_timestep (`int`):
+ The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.
+
+ Returns:
+ `torch.Tensor`:
+ The variance for the current timestep.
+ """
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
@@ -328,6 +370,8 @@ def _get_variance(self, timestep, prev_timestep):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
@@ -335,6 +379,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -634,6 +686,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -656,6 +724,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+        Compute the velocity prediction `v = sqrt(alpha_prod_t) * noise - sqrt(1 - alpha_prod_t) * sample` used by
+        `v_prediction`-type models.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
@@ -679,6 +762,17 @@ def __len__(self):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
+ """
+ Compute the previous timestep in the diffusion chain.
+
+ Args:
+ timestep (`int`):
+ The current timestep.
+
+ Returns:
+ `int`:
+ The previous timestep.
+ """
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py
index d78efabfbc57..5a978dec649b 100644
--- a/src/diffusers/schedulers/scheduling_unclip.py
+++ b/src/diffusers/schedulers/scheduling_unclip.py
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -46,10 +46,10 @@ class UnCLIPSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -57,16 +57,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -334,6 +335,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 162a34bd2774..7dc5f467680b 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -16,7 +16,7 @@
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -32,10 +32,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -43,16 +43,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -80,13 +81,13 @@ def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
-
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -297,7 +298,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -432,6 +433,8 @@ def set_timesteps(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
@@ -439,6 +442,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -465,6 +476,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -500,7 +524,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -526,7 +563,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -550,7 +599,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 6884d3be9292..61f174ace8e6 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -41,7 +41,7 @@
from .deprecation_utils import _maybe_remap_transformers_class, deprecate
from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module
-from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
+from .export_utils import export_to_gif, export_to_merged_video_audio, export_to_obj, export_to_ply, export_to_video
from .hub_utils import (
PushToHubMixin,
_add_variant,
@@ -125,7 +125,7 @@
is_xformers_version,
requires_backends,
)
-from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
+from .loading_utils import get_module_from_name, get_submodule_by_name, load_audio, load_image, load_video
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index a18f28606b3e..c46fa4363483 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -46,7 +46,6 @@
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
-DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 81eb2569e303..f56a8b932505 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -1623,6 +1623,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class WanAnimateTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class WanTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 19f6c0f58440..ecc9b8eb3cf5 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -2147,6 +2147,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class SanaImageToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class SanaPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -3512,6 +3527,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class WanAnimatePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class WanImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -3542,6 +3572,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class WanSpeechToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class WanVACEPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/export_utils.py b/src/diffusers/utils/export_utils.py
index 07cf46928a44..6fc79418121d 100644
--- a/src/diffusers/utils/export_utils.py
+++ b/src/diffusers/utils/export_utils.py
@@ -1,6 +1,9 @@
import io
+import os
import random
+import shutil
import struct
+import subprocess
import tempfile
from contextlib import contextmanager
from typing import List, Optional, Union
@@ -207,3 +210,62 @@ def export_to_video(
writer.append_data(frame)
return output_video_path
+
+
+def export_to_merged_video_audio(video_path: str, audio_path: str):
+ """
+ Merge the video and audio into a new video, with the duration set to the shorter of the two, and overwrite the
+ original video file.
+
+    Args:
+        video_path (`str`):
+            Path to the original video file.
+        audio_path (`str`):
+            Path to the audio file.
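+
+    Example (illustrative file paths; requires the `ffmpeg` CLI to be available on the system `PATH`):
+
+    ```py
+    >>> from diffusers.utils import export_to_merged_video_audio
+
+    >>> export_to_merged_video_audio("generated.mp4", "speech.wav")
+    ```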
+ """
+ if not os.path.exists(video_path):
+ raise FileNotFoundError(f"video file {video_path} does not exist")
+ if not os.path.exists(audio_path):
+ raise FileNotFoundError(f"audio file {audio_path} does not exist")
+
+ base, ext = os.path.splitext(video_path)
+ temp_output = f"{base}_temp{ext}"
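+    # Write the merged result to a sibling temp file first; the original video is only overwritten after ffmpeg
+    # succeeds.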
+
+ try:
+ # Create ffmpeg command
+ command = [
+ "ffmpeg",
+ "-y", # overwrite
+ "-i",
+ video_path,
+ "-i",
+ audio_path,
+ "-c:v",
+ "copy", # copy video stream
+ "-c:a",
+ "aac", # use AAC audio encoder
+ "-b:a",
+ "192k", # set audio bitrate (optional)
+ "-map",
+ "0:v:0", # select the first video stream
+ "-map",
+ "1:a:0", # select the first audio stream
+ "-shortest", # choose the shortest duration
+ temp_output,
+ ]
+
+ # Execute the command
+ logger.info("Start merging video and audio...")
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+
+ # Check result
+ if result.returncode != 0:
+            error_msg = f"FFmpeg execution failed: {result.stderr}"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg)
+
+ shutil.move(temp_output, video_path)
+ logger.info(f"Merge completed, saved to {video_path}")
+
+    except Exception as e:
+        if os.path.exists(temp_output):
+            os.remove(temp_output)
+        logger.error(f"merge_video_audio failed with error: {e}")
+        raise
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
deleted file mode 100644
index 26d6e3972fb7..000000000000
--- a/src/diffusers/utils/kernels_utils.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from ..utils import get_logger
-from .import_utils import is_kernels_available
-
-
-logger = get_logger(__name__)
-
-
-_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
-
-
-def _get_fa3_from_hub():
- if not is_kernels_available():
- return None
- else:
- from kernels import get_kernel
-
- try:
- # TODO: temporary revision for now. Remove when merged upstream into `main`.
- flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
- return flash_attn_3_hub
- except Exception as e:
- logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
- raise
diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py
index dd23ae73c861..29e8a7855fdd 100644
--- a/src/diffusers/utils/loading_utils.py
+++ b/src/diffusers/utils/loading_utils.py
@@ -3,6 +3,8 @@
from typing import Any, Callable, List, Optional, Tuple, Union
from urllib.parse import unquote, urlparse
+import librosa
+import numpy
import PIL.Image
import PIL.ImageOps
import requests
@@ -57,6 +59,9 @@ def load_image(
def load_video(
video: str,
convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
+ n_frames: Optional[int] = None,
+ target_fps: Optional[int] = None,
+ reverse: bool = False,
) -> List[PIL.Image.Image]:
"""
Loads `video` to a list of PIL Image.
@@ -67,6 +72,13 @@ def load_video(
convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
A conversion method to apply to the video after loading it. When set to `None` the images will be converted
to "RGB".
+        n_frames (`int`, *optional*):
+            Number of frames to sample from the video. If `None`, all frames are loaded. Subsampling only takes
+            effect when both `n_frames` and `target_fps` are provided.
+        target_fps (`int`, *optional*):
+            Target sampling frame rate. If `None`, the original frame rate is used.
+        reverse (`bool`, *optional*, defaults to `False`):
+            If `True`, samples frames starting from the beginning of the video; if `False`, anchors the sampling
+            window at the end of the video so the last frames are kept.
Returns:
`List[PIL.Image.Image]`:
@@ -125,9 +137,40 @@ def load_video(
)
with imageio.get_reader(video) as reader:
- # Read all frames
- for frame in reader:
- pil_images.append(PIL.Image.fromarray(frame))
+ # Determine which frames to sample
+ if n_frames is not None and target_fps is not None:
+ # Get video metadata
+ total_frames = reader.count_frames()
+ original_fps = reader.get_meta_data().get("fps")
+
+ # Calculate sampling interval based on target fps
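+            # Example: original_fps=30 and target_fps=15 give interval=2 (every 2nd frame is read); with n_frames=8
+            # the required span is (8 - 1) * 2 = 14 source frames.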
+ interval = max(1, round(original_fps / target_fps))
+ required_span = (n_frames - 1) * interval
+
+ if reverse:
+ start_frame = 0
+ else:
+ start_frame = max(0, total_frames - required_span - 1)
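+                # Anchor the sampling window so the last sampled frame lands near the end of the video.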
+
+ # Generate sampling indices
+ sampled_indices = []
+ for i in range(n_frames):
+                    frame_index = start_frame + i * interval
+                    if frame_index >= total_frames:
+                        break
+                    sampled_indices.append(int(frame_index))
+
+ # Read specific frames
+ for idx in sampled_indices:
+ try:
+ frame = reader.get_data(idx)
+ pil_images.append(PIL.Image.fromarray(frame))
+ except IndexError:
+ break
+ else:
+ # Read all frames
+ for frame in reader:
+ pil_images.append(PIL.Image.fromarray(frame))
if was_tempfile_created:
os.remove(video_path)
@@ -138,6 +181,53 @@ def load_video(
return pil_images
+def load_audio(
+ audio: Union[str, numpy.ndarray], convert_method: Optional[Callable[[numpy.ndarray], numpy.ndarray]] = None
+) -> Tuple[numpy.ndarray, int]:
+ """
+ Loads `audio` to a numpy array.
+
+ Args:
+ audio (`str` or `numpy.ndarray`):
+ The audio to convert to the numpy array format.
+        convert_method (`Callable[[numpy.ndarray], numpy.ndarray]`, *optional*):
+            A conversion method to apply to the audio after loading it. When set to `None`, the audio is returned as
+            loaded.
+
+    Returns:
+        `Tuple[numpy.ndarray, int]`:
+            The audio waveform as a numpy array (resampled to 16 kHz when loaded from a file or URL) and its sample
+            rate.
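+
+    Example (the URL is a placeholder):
+
+    ```py
+    >>> from diffusers.utils import load_audio
+
+    >>> waveform, sampling_rate = load_audio("https://example.com/speech.wav")
+    ```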
+ """
+ if isinstance(audio, str):
+ if audio.startswith("http://") or audio.startswith("https://"):
+ # Download audio from URL and load with librosa
+ response = requests.get(audio, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)
+ with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+ for chunk in response.iter_content(chunk_size=8192):
+ temp_file.write(chunk)
+ temp_audio_path = temp_file.name
+
+ audio, sample_rate = librosa.load(temp_audio_path, sr=16000)
+ os.remove(temp_audio_path) # Clean up temporary file
+ elif os.path.isfile(audio):
+ audio, sample_rate = librosa.load(audio, sr=16000)
+ else:
+ raise ValueError(
+ f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {audio} is not a valid path."
+ )
+ elif isinstance(audio, numpy.ndarray):
+        sample_rate = 16000  # Default sample rate assumed for raw numpy arrays
+ else:
+ raise ValueError(
+ "Incorrect format used for the audio. Should be a URL linking to an audio, a local path, or a numpy array."
+ )
+
+    if convert_method is not None:
+        audio = convert_method(audio)
+
+    return audio, sample_rate
+
+
# Taken from `transformers`.
def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
if "." in tensor_name:
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index f760a1bf7261..3b66fdadbef8 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -242,8 +242,8 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
def apply_freeu(
resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs
) -> Tuple["torch.Tensor", "torch.Tensor"]:
- """Applies the FreeU mechanism as introduced in https:
- //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.
+ """Applies the FreeU mechanism as introduced in https://huggingface.co/papers/2309.11497. Adapted from the official
+ code repository: https://github.com/ChenyangSi/FreeU.
Args:
resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied.
diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py
index abeb30bca102..1841184c9d6d 100644
--- a/src/diffusers/video_processor.py
+++ b/src/diffusers/video_processor.py
@@ -26,7 +26,9 @@
class VideoProcessor(VaeImageProcessor):
r"""Simple video processor."""
- def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor:
+ def preprocess_video(
+ self, video, height: Optional[int] = None, width: Optional[int] = None, resize_mode: str = "default"
+ ) -> torch.Tensor:
r"""
Preprocesses input video(s).
@@ -50,6 +52,9 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[
width (`int`, *optional*`, defaults to `None`):
The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get
the default width.
+ resize_mode (`str`, *optional*, defaults to `default`):
+ The resize mode, can be one of `default`, `fill`, `crop`, or `center_crop`. See
+ `VaeImageProcessor.preprocess` for detailed descriptions of each mode.
"""
if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
warnings.warn(
@@ -80,7 +85,9 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[
"Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image"
)
- video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0)
+ video = torch.stack(
+ [self.preprocess(img, height=height, width=width, resize_mode=resize_mode) for img in video], dim=0
+ )
# move the number of channels before the number of frames.
video = video.permute(0, 2, 1, 3, 4)
diff --git a/tests/models/autoencoders/test_models_autoencoder_dc.py b/tests/models/autoencoders/test_models_autoencoder_dc.py
index d34001e7b903..b1b5531d0134 100644
--- a/tests/models/autoencoders/test_models_autoencoder_dc.py
+++ b/tests/models/autoencoders/test_models_autoencoder_dc.py
@@ -82,3 +82,7 @@ def prepare_init_args_and_inputs_for_common(self):
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference(self):
super().test_layerwise_casting_inference()
+
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_memory(self):
+ super().test_layerwise_casting_memory()
diff --git a/tests/models/transformers/test_models_transformer_wan_animate.py b/tests/models/transformers/test_models_transformer_wan_animate.py
new file mode 100644
index 000000000000..5d571b8c2e7d
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_wan_animate.py
@@ -0,0 +1,126 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import WanAnimateTransformer3DModel
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = WanAnimateTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 4
+ num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ clip_seq_len = 12
+ clip_dim = 16
+
+ inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
+ face_height = 16 # Should be square and match `motion_encoder_size` below
+ face_width = 16
+
+ hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device)
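+        # 2 * num_channels + 4 = 12 input channels here, matching the `in_channels` value used in
+        # `prepare_init_args_and_inputs_for_common`.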
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+ clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device)
+ pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ face_pixel_values = torch.randn((batch_size, 3, inference_segment_length, face_height, face_width)).to(
+ torch_device
+ )
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_hidden_states_image": clip_ref_features,
+ "pose_hidden_states": pose_latents,
+ "face_pixel_values": face_pixel_values,
+ }
+
+ @property
+ def input_shape(self):
+ return (12, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ # Use custom channel sizes: with the default Wan Animate sizes, the motion encoder would contain the vast
+ # majority of the parameters in the test model
+ channel_sizes = {"4": 16, "8": 16, "16": 16}
+
+ init_dict = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "in_channels": 12, # 2 * C + 4 = 2 * 4 + 4 = 12
+ "latent_channels": 4,
+ "out_channels": 4,
+ "text_dim": 16,
+ "freq_dim": 256,
+ "ffn_dim": 32,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "image_dim": 16,
+ "rope_max_seq_len": 32,
+ "motion_encoder_channel_sizes": channel_sizes, # Start of Wan Animate-specific config
+ "motion_encoder_size": 16, # Ensures that there will be 2 motion encoder resblocks
+ "motion_style_dim": 8,
+ "motion_dim": 4,
+ "motion_encoder_dim": 16,
+ "face_encoder_hidden_dim": 16,
+ "face_encoder_num_heads": 2,
+ "inject_face_latents_blocks": 2,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"WanAnimateTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ # Override test_output because the transformer output is expected to have fewer channels than the main
+ # transformer input.
+ def test_output(self):
+ expected_output_shape = (1, 4, 21, 16, 16)
+ super().test_output(expected_output_shape=expected_output_shape)
+
+
+class WanAnimateTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = WanAnimateTransformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return WanAnimateTransformer3DTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py
index 2e5a2fc82bb6..273e3f9c0721 100644
--- a/tests/others/test_attention_backends.py
+++ b/tests/others/test_attention_backends.py
@@ -7,7 +7,6 @@
```bash
export RUN_ATTENTION_BACKEND_TESTS=yes
-export DIFFUSERS_ENABLE_HUB_KERNELS=yes
pytest tests/others/test_attention_backends.py
```
diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py
index 14ff1272a29e..5ccba1dabbfe 100644
--- a/tests/pipelines/audioldm2/test_audioldm2.py
+++ b/tests/pipelines/audioldm2/test_audioldm2.py
@@ -21,11 +21,9 @@
import pytest
import torch
from transformers import (
- ClapAudioConfig,
ClapConfig,
ClapFeatureExtractor,
ClapModel,
- ClapTextConfig,
GPT2Config,
GPT2LMHeadModel,
RobertaTokenizer,
@@ -111,33 +109,33 @@ def get_dummy_components(self):
latent_channels=4,
)
torch.manual_seed(0)
- text_branch_config = ClapTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- audio_branch_config = ClapAudioConfig(
- spec_size=8,
- window_size=4,
- num_mel_bins=8,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- depths=[1, 1],
- num_attention_heads=[1, 1],
- num_hidden_layers=1,
- hidden_size=192,
- projection_dim=8,
- patch_size=2,
- patch_stride=2,
- patch_embed_input_channels=4,
- )
+ text_branch_config = {
+ "bos_token_id": 0,
+ "eos_token_id": 2,
+ "hidden_size": 8,
+ "intermediate_size": 37,
+ "layer_norm_eps": 1e-05,
+ "num_attention_heads": 1,
+ "num_hidden_layers": 1,
+ "pad_token_id": 1,
+ "vocab_size": 1000,
+ "projection_dim": 8,
+ }
+ audio_branch_config = {
+ "spec_size": 8,
+ "window_size": 4,
+ "num_mel_bins": 8,
+ "intermediate_size": 37,
+ "layer_norm_eps": 1e-05,
+ "depths": [1, 1],
+ "num_attention_heads": [1, 1],
+ "num_hidden_layers": 1,
+ "hidden_size": 192,
+ "projection_dim": 8,
+ "patch_size": 2,
+ "patch_stride": 2,
+ "patch_embed_input_channels": 4,
+ }
text_encoder_config = ClapConfig(
text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=16
)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
index 476fc584cc56..62f5853da9a5 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
@@ -23,7 +23,7 @@
KandinskyV22InpaintCombinedPipeline,
)
-from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
+from ...testing_utils import enable_full_determinism, require_accelerator, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
from .test_kandinsky import Dummies
from .test_kandinsky_img2img import Dummies as Img2ImgDummies
@@ -402,6 +402,7 @@ def test_save_load_local(self):
def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=5e-4)
+ @require_accelerator
def test_sequential_cpu_offload_forward_pass(self):
super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
index d4eb650263af..8a693e9c2dd0 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
@@ -37,6 +37,7 @@
load_image,
load_numpy,
numpy_cosine_similarity_distance,
+ require_accelerator,
require_torch_accelerator,
slow,
torch_device,
@@ -254,6 +255,7 @@ def test_model_cpu_offload_forward_pass(self):
def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=5e-4)
+ @require_accelerator
def test_sequential_cpu_offload_forward_pass(self):
super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4)
diff --git a/tests/pipelines/sana_video/__init__.py b/tests/pipelines/sana_video/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/sana/test_sana_video.py b/tests/pipelines/sana_video/test_sana_video.py
similarity index 100%
rename from tests/pipelines/sana/test_sana_video.py
rename to tests/pipelines/sana_video/test_sana_video.py
diff --git a/tests/pipelines/sana_video/test_sana_video_i2v.py b/tests/pipelines/sana_video/test_sana_video_i2v.py
new file mode 100644
index 000000000000..36a646ca528f
--- /dev/null
+++ b/tests/pipelines/sana_video/test_sana_video_i2v.py
@@ -0,0 +1,238 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
+
+from diffusers import (
+ AutoencoderKLWan,
+ FlowMatchEulerDiscreteScheduler,
+ SanaImageToVideoPipeline,
+ SanaVideoTransformer3DModel,
+)
+
+from ...testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class SanaImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SanaImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ text_encoder_config = Gemma2Config(
+ head_dim=16,
+ hidden_size=8,
+ initializer_range=0.02,
+ intermediate_size=64,
+ max_position_embeddings=8192,
+ model_type="gemma2",
+ num_attention_heads=2,
+ num_hidden_layers=1,
+ num_key_value_heads=2,
+ vocab_size=8,
+ attn_implementation="eager",
+ )
+ text_encoder = Gemma2Model(text_encoder_config)
+ tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+
+ torch.manual_seed(0)
+ transformer = SanaVideoTransformer3DModel(
+ in_channels=16,
+ out_channels=16,
+ num_attention_heads=2,
+ attention_head_dim=12,
+ num_layers=2,
+ num_cross_attention_heads=2,
+ cross_attention_head_dim=12,
+ cross_attention_dim=24,
+ caption_channels=8,
+ mlp_ratio=2.5,
+ dropout=0.0,
+ attention_bias=False,
+ sample_size=8,
+ patch_size=(1, 2, 2),
+ norm_elementwise_affine=False,
+ norm_eps=1e-6,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ # Create a dummy image input (PIL Image)
+ image = Image.new("RGB", (32, 32))
+
+ inputs = {
+ "image": image,
+ "prompt": "",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "complex_human_instruction": [],
+ "use_resolution_binning": False,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_save_load_local(self, expected_max_difference=5e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
+
+ @unittest.skip("Skipping fp16 test as model is trained with bf16")
+ def test_float16_inference(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_float16_inference(expected_max_diff=0.08)
+
+ @unittest.skip("Skipping fp16 test as model is trained with bf16")
+ def test_save_load_float16(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_save_load_float16(expected_max_diff=0.2)
+
+
+@slow
+@require_torch_accelerator
+class SanaVideoPipelineIntegrationTests(unittest.TestCase):
+ prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ @unittest.skip("TODO: test needs to be implemented")
+ def test_sana_video_480p(self):
+ pass
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py
index 2e4b428dfeb5..285c2fea7ebc 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py
@@ -37,6 +37,7 @@
floats_tensor,
load_image,
load_numpy,
+ require_accelerator,
require_torch_accelerator,
slow,
torch_device,
@@ -222,6 +223,7 @@ def test_stable_diffusion_latent_upscaler_multiple_init_images(self):
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=7e-3)
+ @require_accelerator
def test_sequential_cpu_offload_forward_pass(self):
super().test_sequential_cpu_offload_forward_pass(expected_max_diff=3e-3)
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 2af4ad0314c3..e2bbce7b0ead 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -1422,7 +1422,18 @@ def test_float16_inference(self, expected_max_diff=5e-2):
def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
for name, module in components.items():
- if hasattr(module, "half"):
+ # Account for components with _keep_in_fp32_modules
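+ # (these modules must stay in float32 even when the rest of the pipeline is cast to float16, so cast per parameter)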
+ if hasattr(module, "_keep_in_fp32_modules") and module._keep_in_fp32_modules is not None:
+ for param_name, param in module.named_parameters():
+ if any(
+ module_to_keep_in_fp32 in param_name.split(".")
+ for module_to_keep_in_fp32 in module._keep_in_fp32_modules
+ ):
+ param.data = param.data.to(torch_device).to(torch.float32)
+ else:
+ param.data = param.data.to(torch_device).to(torch.float16)
+
+ elif hasattr(module, "half"):
components[name] = module.to(torch_device).half()
pipe = self.pipeline_class(**components)
diff --git a/tests/pipelines/wan/test_wan_animate.py b/tests/pipelines/wan/test_wan_animate.py
new file mode 100644
index 000000000000..d6d1b09f3620
--- /dev/null
+++ b/tests/pipelines/wan/test_wan_animate.py
@@ -0,0 +1,239 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import (
+ AutoTokenizer,
+ CLIPImageProcessor,
+ CLIPVisionConfig,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+)
+
+from diffusers import (
+ AutoencoderKLWan,
+ FlowMatchEulerDiscreteScheduler,
+ WanAnimatePipeline,
+ WanAnimateTransformer3DModel,
+)
+
+from ...testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class WanAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanAnimatePipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ channel_sizes = {"4": 16, "8": 16, "16": 16}
+ transformer = WanAnimateTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ latent_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ image_dim=4,
+ rope_max_seq_len=32,
+ motion_encoder_channel_sizes=channel_sizes,
+ motion_encoder_size=16,
+ motion_style_dim=8,
+ motion_dim=4,
+ motion_encoder_dim=16,
+ face_encoder_hidden_dim=16,
+ face_encoder_num_heads=2,
+ inject_face_latents_blocks=2,
+ )
+
+ torch.manual_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=4,
+ projection_dim=4,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ image_size=4,
+ intermediate_size=16,
+ patch_size=1,
+ )
+ image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+ torch.manual_seed(0)
+ image_processor = CLIPImageProcessor(crop_size=4, size=4)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "image_encoder": image_encoder,
+ "image_processor": image_processor,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ num_frames = 17
+ height = 16
+ width = 16
+ face_height = 16
+ face_width = 16
+
+ image = Image.new("RGB", (height, width))
+ pose_video = [Image.new("RGB", (height, width))] * num_frames
+ face_video = [Image.new("RGB", (face_height, face_width))] * num_frames
+
+ inputs = {
+ "image": image,
+ "pose_video": pose_video,
+ "face_video": face_video,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative",
+ "height": height,
+ "width": width,
+ "segment_frame_length": 77, # TODO: can we set this to num_frames?
+ "num_inference_steps": 2,
+ "mode": "animate",
+ "prev_segment_conditioning_frames": 1,
+ "generator": generator,
+ "guidance_scale": 1.0,
+ "output_type": "pt",
+ "max_sequence_length": 16,
+ }
+ return inputs
+
+ def test_inference(self):
+ """Test basic inference in animation mode."""
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames[0]
+ self.assertEqual(video.shape, (17, 3, 16, 16))
+
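+ # Loose sanity check only: the huge tolerance below means this comparison effectively never fails;
+ # it just exercises the output tensor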
+ expected_video = torch.randn(17, 3, 16, 16)
+ max_diff = np.abs(video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_inference_replacement(self):
+ """Test the pipeline in replacement mode with background and mask videos."""
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["mode"] = "replace"
+ num_frames = 17
+ height = 16
+ width = 16
+ inputs["background_video"] = [Image.new("RGB", (height, width))] * num_frames
+ inputs["mask_video"] = [Image.new("L", (height, width))] * num_frames
+
+ video = pipe(**inputs).frames[0]
+ self.assertEqual(video.shape, (17, 3, 16, 16))
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip(
+ "Setting the Wan Animate latents to zero at the last denoising step does not guarantee that the output will be"
+ " zero. I believe this is because the latents are further processed in the outer loop where we loop over"
+ " inference segments."
+ )
+ def test_callback_inputs(self):
+ pass
+
+
+@slow
+@require_torch_accelerator
+class WanAnimatePipelineIntegrationTests(unittest.TestCase):
+ prompt = "A painting of a squirrel eating a burger."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ @unittest.skip("TODO: test needs to be implemented")
+ def test_wan_animate(self):
+ pass
diff --git a/tests/pipelines/wan/test_wan_speech_to_video.py b/tests/pipelines/wan/test_wan_speech_to_video.py
new file mode 100644
index 000000000000..7396a151b3be
--- /dev/null
+++ b/tests/pipelines/wan/test_wan_speech_to_video.py
@@ -0,0 +1,244 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel, Wav2Vec2ForCTC, Wav2Vec2Processor
+
+from diffusers import (
+ AutoencoderKLWan,
+ FlowMatchEulerDiscreteScheduler,
+ WanS2VTransformer3DModel,
+ WanSpeechToVideoPipeline,
+)
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class WanSpeechToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanSpeechToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = WanS2VTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=3,
+ num_weighted_avg_layers=5,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ audio_dim=16,
+ audio_inject_layers=[0, 2],
+ enable_adain=True,
+ enable_framepack=True,
+ )
+
+ torch.manual_seed(0)
+ audio_encoder = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
+ audio_processor = Wav2Vec2Processor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "audio_encoder": audio_encoder,
+ "audio_processor": audio_processor,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ # Use 64x64 so that after VAE downsampling (factor ~8) latent spatial size is 8x8, which matches
+ # the frame-packing conv kernel requirement. The largest kernel is (4, 8, 8) so we need at least 8x8 latents.
+ height = 64
+ width = 64
+
+ image = Image.new("RGB", (width, height))
+
+ sampling_rate = 16000
+ audio_length = 0.5
+ # Make audio generation deterministic by using a fixed seed
+ np_rng = np.random.RandomState(seed)
+ audio = np_rng.rand(int(sampling_rate * audio_length)).astype(np.float32)
+
+ inputs = {
+ "image": image,
+ "audio": audio,
+ "sampling_rate": sampling_rate,
+ "prompt": "A person speaking",
+ "negative_prompt": "low quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 4.5,
+ "height": height,
+ "width": width,
+ "num_frames_per_chunk": 4,
+ "num_chunks": 2,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "pose_video_path_or_url": None,
+ "init_first_frame": True,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames[0]
+ expected_num_frames = inputs["num_frames_per_chunk"] * inputs["num_chunks"]
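+ # Without first-frame initialization the pipeline presumably drops 3 frames from the first chunk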
+ if not inputs["init_first_frame"]:
+ expected_num_frames -= 3
+ self.assertEqual(video.shape, (expected_num_frames, 3, inputs["height"], inputs["width"]))
+
+ def test_inference_with_pose(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
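+ # Exercise the pose-conditioning path using the example pose video from the Wan2.2 repository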
+ inputs["pose_video_path_or_url"] = "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4"
+ video = pipe(**inputs).frames[0]
+ expected_num_frames = inputs["num_frames_per_chunk"] * inputs["num_chunks"]
+ if not inputs["init_first_frame"]:
+ expected_num_frames -= 3
+ self.assertEqual(video.shape, (expected_num_frames, 3, inputs["height"], inputs["width"]))
+
+ def test_callback_cfg(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ if "guidance_scale" not in sig.parameters:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_increase_guidance(pipe, i, t, callback_kwargs):
+ pipe._guidance_scale += 1.0
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # use cfg guidance because some pipelines modify the shape of the latents
+ # outside of the denoising loop
+ inputs["guidance_scale"] = 2.0
+ inputs["callback_on_step_end"] = callback_increase_guidance
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ _ = pipe(**inputs)[0]
+
+ # We increase the guidance scale by 1.0 at every step, so check that it has grown by the total number of
+ # scheduler timesteps. This also accounts for models that modify the number of inference steps based on strength.
+ # For this pipeline, the total number of timesteps is multiplied by num_chunks
+ # since each chunk runs independently with its own denoising loop
+ assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps * inputs["num_chunks"])
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("Errors out because passing multiple prompts at once is not yet supported by this pipeline.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ @unittest.skip("Batching is not yet supported with this pipeline")
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip("Batching is not yet supported with this pipeline")
+ def test_inference_batch_single_identical(self):
+ return super().test_inference_batch_single_identical()
+
+ @unittest.skip(
+ "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs"
+ )
+ def test_float16_inference(self):
+ pass
+
+ @unittest.skip(
+ "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs"
+ )
+ def test_save_load_float16(self):
+ pass
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index 0f4fd408a7c1..98969b55b727 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -16,6 +16,8 @@
HiDreamImageTransformer2DModel,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
+ WanAnimateTransformer3DModel,
+ WanS2VTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
)
@@ -721,6 +723,60 @@ def get_dummy_inputs(self):
}
+class WanS2VGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/QuantStack/Wan2.2-S2V-14B-GGUF/blob/main/Wan2.2-S2V-14B-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanS2VTransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states": torch.randn(
+ (1, 96, 2, 64, 64),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states_scale": torch.randn(
+ (8,),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+class WanAnimateGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanAnimateTransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states": torch.randn(
+ (1, 96, 2, 64, 64),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states_scale": torch.randn(
+ (8,),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
@require_torch_version_greater("2.7.1")
class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
torch_dtype = torch.bfloat16
diff --git a/utils/check_doc_toc.py b/utils/check_doc_toc.py
index 0dd02cde86c1..050b093991e6 100644
--- a/utils/check_doc_toc.py
+++ b/utils/check_doc_toc.py
@@ -21,20 +21,23 @@
PATH_TO_TOC = "docs/source/en/_toctree.yml"
+# Titles that should maintain their position and not be sorted alphabetically
+FIXED_POSITION_TITLES = {"overview", "autopipeline"}
+
def clean_doc_toc(doc_list):
"""
Cleans the table of content of the model documentation by removing duplicates and sorting models alphabetically.
"""
counts = defaultdict(int)
- overview_doc = []
+ fixed_position_docs = []
new_doc_list = []
for doc in doc_list:
if "local" in doc:
counts[doc["local"]] += 1
- if doc["title"].lower() == "overview":
- overview_doc.append({"local": doc["local"], "title": doc["title"]})
+ if doc["title"].lower() in FIXED_POSITION_TITLES:
+ fixed_position_docs.append({"local": doc["local"], "title": doc["title"]})
else:
new_doc_list.append(doc)
@@ -57,14 +60,13 @@ def clean_doc_toc(doc_list):
new_doc.extend([doc for doc in doc_list if "local" not in counts or counts[doc["local"]] == 1])
new_doc = sorted(new_doc, key=lambda s: s["title"].lower())
- # "overview" gets special treatment and is always first
- if len(overview_doc) > 1:
- raise ValueError("{doc_list} has two 'overview' docs which is not allowed.")
-
- overview_doc.extend(new_doc)
+ # Fixed-position titles maintain their original order
+ result = list(fixed_position_docs)
- # Sort
- return overview_doc
+ result.extend(new_doc)
+ return result
def check_scheduler_doc(overwrite=False):