From f3db38c1e74c49de0ff42356515e7b394561ec39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 12 Nov 2025 09:37:21 -0700 Subject: [PATCH 01/35] ArXiv -> HF Papers (#12583) * Update pipeline_skyreels_v2_i2v.py * Update README.md * Update torch_utils.py * Update torch_utils.py * Update guider_utils.py * Update pipeline_ltx.py * Update pipeline_bria.py * Apply suggestion from @qgallouedec * Update autoencoder_kl_qwenimage.py * Update pipeline_prx.py * Update pipeline_wan_vace.py * Update pipeline_skyreels_v2.py * Update pipeline_skyreels_v2_diffusion_forcing.py * Update pipeline_bria_fibo.py * Update pipeline_skyreels_v2_diffusion_forcing_i2v.py * Update pipeline_ltx_condition.py * Update pipeline_ltx_image2video.py * Update regional_prompting_stable_diffusion.py * make style * style * style --- examples/community/README.md | 2 +- .../community/regional_prompting_stable_diffusion.py | 10 +++++----- src/diffusers/guiders/guider_utils.py | 2 +- .../models/autoencoders/autoencoder_kl_qwenimage.py | 2 +- src/diffusers/pipelines/bria/pipeline_bria.py | 12 ++++++------ .../pipelines/bria_fibo/pipeline_bria_fibo.py | 12 ++++++------ src/diffusers/pipelines/ltx/pipeline_ltx.py | 9 +++++---- .../pipelines/ltx/pipeline_ltx_condition.py | 9 +++++---- .../pipelines/ltx/pipeline_ltx_image2video.py | 9 +++++---- src/diffusers/pipelines/prx/pipeline_prx.py | 10 +++++----- .../pipelines/skyreels_v2/pipeline_skyreels_v2.py | 10 +++++----- .../pipeline_skyreels_v2_diffusion_forcing.py | 10 +++++----- .../pipeline_skyreels_v2_diffusion_forcing_i2v.py | 10 +++++----- .../skyreels_v2/pipeline_skyreels_v2_i2v.py | 10 +++++----- src/diffusers/pipelines/wan/pipeline_wan_vace.py | 10 +++++----- src/diffusers/utils/torch_utils.py | 4 ++-- 16 files changed, 67 insertions(+), 64 deletions(-) diff --git a/examples/community/README.md b/examples/community/README.md index 4a4b0f5fd9f5..69e9c7576103 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -5488,7 +5488,7 @@ Editing at Scale", many thanks to their contribution! This implementation of Flux Kontext allows users to pass multiple reference images. Each image is encoded separately, and the resulting latent vectors are concatenated. -As explained in Section 3 of [the paper](https://arxiv.org/pdf/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work. +As explained in Section 3 of [the paper](https://huggingface.co/papers/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work. 
## Example Usage diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py index bca67e3959d8..3bc780cfcf7a 100644 --- a/examples/community/regional_prompting_stable_diffusion.py +++ b/examples/community/regional_prompting_stable_diffusion.py @@ -490,7 +490,7 @@ def hook_forwards(root_module: torch.nn.Module): def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) @@ -841,7 +841,7 @@ def stable_diffusion_call( num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make @@ -872,7 +872,7 @@ def stable_diffusion_call( [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when using zero terminal SNR. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that @@ -1062,7 +1062,7 @@ def stable_diffusion_call( noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + # Based on 3.4. in https://huggingface.co/papers/2305.08891 noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 @@ -1668,7 +1668,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): r""" Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Flawed](https://huggingface.co/papers/2305.08891). Args: noise_cfg (`torch.Tensor`): diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 52cb0ce34980..6c328328fc3b 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -373,7 +373,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): r""" Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. 
Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Flawed](https://huggingface.co/papers/2305.08891). Args: noise_cfg (`torch.Tensor`): diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index 0aadbad9f4de..618801dfb605 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -16,7 +16,7 @@ # QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance. # For more information about the Wan VAE, please refer to: # - GitHub: https://github.com/Wan-Video/Wan2.1 -# - arXiv: https://arxiv.org/abs/2503.20314 +# - Paper: https://huggingface.co/papers/2503.20314 from typing import List, Optional, Tuple, Union diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py index ebddfb0c0eee..a22a756005ac 100644 --- a/src/diffusers/pipelines/bria/pipeline_bria.py +++ b/src/diffusers/pipelines/bria/pipeline_bria.py @@ -245,7 +245,7 @@ def guidance_scale(self): return self._guidance_scale # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @property def do_classifier_free_guidance(self): @@ -489,11 +489,11 @@ def __call__( in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 5.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py index 85d29029e667..c66b64766edc 100644 --- a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py +++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py @@ -337,7 +337,7 @@ def guidance_scale(self): return self._guidance_scale # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1` # corresponds to doing no classifier free guidance. @property @@ -498,11 +498,11 @@ def __call__( in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 5.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index bd23e657c408..8ca8b4419e18 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -590,9 +590,10 @@ def __call__( the text `prompt`, usually at the expense of lower image quality. guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of - [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). - Guidance rescale factor should fix overexposure when using zero terminal SNR. + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -777,7 +778,7 @@ def __call__( noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) if self.guidance_rescale > 0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + # Based on 3.4. in https://huggingface.co/papers/2305.08891 noise_pred = rescale_noise_cfg( noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale ) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index 537588f67c95..48a6f0837c8d 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -927,9 +927,10 @@ def __call__( the text `prompt`, usually at the expense of lower image quality. 
guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of - [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). - Guidance rescale factor should fix overexposure when using zero terminal SNR. + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -1194,7 +1195,7 @@ def __call__( timestep, _ = timestep.chunk(2) if self.guidance_rescale > 0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + # Based on 3.4. in https://huggingface.co/papers/2305.08891 noise_pred = rescale_noise_cfg( noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale ) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 694378b4f040..f30f8a3dc8f6 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -654,9 +654,10 @@ def __call__( the text `prompt`, usually at the expense of lower image quality. guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of - [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). - Guidance rescale factor should fix overexposure when using zero terminal SNR. + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -851,7 +852,7 @@ def __call__( timestep, _ = timestep.chunk(2) if self.guidance_rescale > 0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + # Based on 3.4. in https://huggingface.co/papers/2305.08891 noise_pred = rescale_noise_cfg( noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale ) diff --git a/src/diffusers/pipelines/prx/pipeline_prx.py b/src/diffusers/pipelines/prx/pipeline_prx.py index a3bd3e6b45e7..df598a5715d2 100644 --- a/src/diffusers/pipelines/prx/pipeline_prx.py +++ b/src/diffusers/pipelines/prx/pipeline_prx.py @@ -536,11 +536,11 @@ def __call__( in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 4.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). 
Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index 8562a5eaf0e6..d6cd7d7feceb 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -415,11 +415,11 @@ def __call__( The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, defaults to `6.0`): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index d0a4e118ce43..089f92632d38 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -647,11 +647,11 @@ def __call__( The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, defaults to `6.0`): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**) + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). 
`guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**) num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 959cbb32f23a..2951a9447386 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -698,11 +698,11 @@ def __call__( The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, defaults to `5.0`): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**) + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**) num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py index d59b4ce3cb17..d61b687eadc3 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -524,11 +524,11 @@ def __call__( The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, defaults to `5.0`): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py index 63e557a98fbe..351ae2e70563 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py @@ -758,11 +758,11 @@ def __call__( The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, defaults to `5.0`): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. guidance_scale_2 (`float`, *optional*, defaults to `None`): Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index f760a1bf7261..3b66fdadbef8 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -242,8 +242,8 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T def apply_freeu( resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs ) -> Tuple["torch.Tensor", "torch.Tensor"]: - """Applies the FreeU mechanism as introduced in https: - //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU. + """Applies the FreeU mechanism as introduced in https://huggingface.co/papers/2309.11497. Adapted from the official + code repository: https://github.com/ChenyangSi/FreeU. Args: resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied. From 2f44d630460666c7e8acd54a54e81340ad71526f Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Wed, 12 Nov 2025 09:21:24 -0800 Subject: [PATCH 02/35] [docs] Update install instructions (#12626) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit remove commit Removed specific commit reference for installation instructions. 
Co-authored-by: Sayak Paul Co-authored-by: Álvaro Somoza --- examples/dreambooth/README_flux.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index 242f018b654b..42edbb122136 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -268,12 +268,11 @@ provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_f **important** > [!NOTE] -> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source, specifically from the commit mentioned below. +> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source. > To do this, execute the following steps in a new virtual environment: > ``` > git clone https://github.com/huggingface/diffusers > cd diffusers -> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b > pip install -e . > ``` From d6c63bb956358f1990443a849ca250419a238b95 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 12 Nov 2025 07:59:18 -1000 Subject: [PATCH 03/35] [modular] add a check (#12628) * add * fix --- src/diffusers/modular_pipelines/modular_pipeline.py | 4 ++++ src/diffusers/modular_pipelines/qwenimage/modular_blocks.py | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 151adbbc0320..c7285e38fda2 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -861,6 +861,10 @@ def __init__(self): else: sub_blocks[block_name] = block self.sub_blocks = sub_blocks + if not len(self.block_names) == len(self.block_classes): + raise ValueError( + f"In {self.__class__.__name__}, the number of block_names and block_classes must be the same." 
+ ) def _get_inputs(self): inputs = [] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py index 83bfcb3da4fd..419894164389 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py @@ -523,7 +523,7 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): QwenImageOptionalControlNetBeforeDenoiseStep, QwenImageAutoDenoiseStep, ] - block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise", "decode"] + block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise"] @property def description(self): @@ -534,7 +534,6 @@ def description(self): + " - `QwenImageAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n" + " - `QwenImageOptionalControlNetBeforeDenoiseStep` (controlnet_before_denoise) prepares the controlnet input for the denoising step.\n" + " - `QwenImageAutoDenoiseStep` (denoise) iteratively denoises the latents.\n" - + " - `QwenImageAutoDecodeStep` (decode) decodes the latents into images.\n\n" + "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n" + " - for image-to-image generation, you need to provide `image_latents`\n" + " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n" From 44c3101685f2cb6807b7157587d4da450ac629c1 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Thu, 13 Nov 2025 02:26:10 +0100 Subject: [PATCH 04/35] Improve docstrings and type hints in scheduling_amused.py (#12623) * Improve docstrings and type hints in scheduling_amused.py - Add complete type hints for helper functions (gumbel_noise, mask_by_random_topk) - Enhance AmusedSchedulerOutput with proper Optional typing - Add comprehensive docstrings for AmusedScheduler class - Improve __init__, set_timesteps, step, and add_noise methods - Fix type hints to match documentation conventions - All changes follow project standards from issue #9567 * Enhance type hints and docstrings in scheduling_amused.py - Update type hints for `prev_sample` and `pred_original_sample` in `AmusedSchedulerOutput` to reflect their tensor types. - Improve docstring for `gumbel_noise` to specify the output tensor's dtype and device. - Refine `AmusedScheduler` class documentation, including detailed descriptions of the masking schedule and temperature parameters. - Adjust type hints in `set_timesteps` and `step` methods for better clarity and consistency. 
* Apply review feedback on scheduling_amused.py - Replace generic [Amused] reference with specific [`AmusedPipeline`] reference for consistency with project documentation conventions --- src/diffusers/schedulers/scheduling_amused.py | 151 +++++++++++++++--- 1 file changed, 132 insertions(+), 19 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_amused.py b/src/diffusers/schedulers/scheduling_amused.py index 238b8d869171..a0b8fbc862b0 100644 --- a/src/diffusers/schedulers/scheduling_amused.py +++ b/src/diffusers/schedulers/scheduling_amused.py @@ -1,6 +1,6 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import torch @@ -9,13 +9,48 @@ from .scheduling_utils import SchedulerMixin -def gumbel_noise(t, generator=None): +def gumbel_noise(t: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor: + """ + Generate Gumbel noise for sampling. + + Args: + t (`torch.Tensor`): + Input tensor to match the shape and dtype of the output noise. + generator (`torch.Generator`, *optional*): + A random number generator for reproducible sampling. + + Returns: + `torch.Tensor`: + Gumbel-distributed noise with the same shape, dtype, and device as the input tensor. + """ device = generator.device if generator is not None else t.device noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device) return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20)) -def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None): +def mask_by_random_topk( + mask_len: torch.Tensor, + probs: torch.Tensor, + temperature: float = 1.0, + generator: Optional[torch.Generator] = None, +) -> torch.Tensor: + """ + Mask tokens by selecting the top-k lowest confidence scores with temperature-based randomness. + + Args: + mask_len (`torch.Tensor`): + Number of tokens to mask per sample in the batch. + probs (`torch.Tensor`): + Probability scores for each token. + temperature (`float`, *optional*, defaults to 1.0): + Temperature parameter for controlling randomness in the masking process. + generator (`torch.Generator`, *optional*): + A random number generator for reproducible sampling. + + Returns: + `torch.Tensor`: + Boolean mask indicating which tokens should be masked. + """ confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator) sorted_confidence = torch.sort(confidence, dim=-1).values cut_off = torch.gather(sorted_confidence, 1, mask_len.long()) @@ -29,28 +64,46 @@ class AmusedSchedulerOutput(BaseOutput): Output class for the scheduler's `step` function output. Args: - prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample `(x_{0})` based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. + prev_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`): + Computed sample `(x_{t-1})` of previous timestep with token IDs. `prev_sample` should be used as next model + input in the denoising loop. 
+ pred_original_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`, *optional*): + The predicted fully denoised sample `(x_{0})` with token IDs based on the model output from the current + timestep. `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.Tensor - pred_original_sample: torch.Tensor = None + pred_original_sample: Optional[torch.Tensor] = None class AmusedScheduler(SchedulerMixin, ConfigMixin): + """ + A scheduler for masked token generation as used in [`AmusedPipeline`]. + + This scheduler iteratively unmasks tokens based on their confidence scores, following either a cosine or linear + schedule. Unlike traditional diffusion schedulers that work with continuous pixel values, this scheduler operates + on discrete token IDs, making it suitable for autoregressive and non-autoregressive masked token generation models. + + This scheduler inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the + generic methods the library implements for all schedulers such as loading and saving. + + Args: + mask_token_id (`int`): + The token ID used to represent masked tokens in the sequence. + masking_schedule (`Literal["cosine", "linear"]`, *optional*, defaults to `"cosine"`): + The schedule type for determining the mask ratio at each timestep. Can be either `"cosine"` or `"linear"`. + """ + order = 1 - temperatures: torch.Tensor + temperatures: Optional[torch.Tensor] + timesteps: Optional[torch.Tensor] @register_to_config def __init__( self, mask_token_id: int, - masking_schedule: str = "cosine", + masking_schedule: Literal["cosine", "linear"] = "cosine", ): self.temperatures = None self.timesteps = None @@ -58,9 +111,23 @@ def __init__( def set_timesteps( self, num_inference_steps: int, - temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), - device: Union[str, torch.device] = None, - ): + temperature: Union[float, Tuple[float, float], List[float]] = (2, 0), + device: Optional[Union[str, torch.device]] = None, + ) -> None: + """ + Set the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + temperature (`Union[float, Tuple[float, float], List[float]]`, *optional*, defaults to `(2, 0)`): + Temperature parameter(s) for controlling the randomness of sampling. If a tuple or list is provided, + temperatures will be linearly interpolated between the first and second values across all timesteps. If + a single value is provided, temperatures will be linearly interpolated from that value to 0.01. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps and temperatures should be moved to. If `None`, the timesteps are not + moved. + """ self.timesteps = torch.arange(num_inference_steps, device=device).flip(0) if isinstance(temperature, (tuple, list)): @@ -71,12 +138,38 @@ def set_timesteps( def step( self, model_output: torch.Tensor, - timestep: torch.long, + timestep: int, sample: torch.LongTensor, - starting_mask_ratio: int = 1, + starting_mask_ratio: float = 1.0, generator: Optional[torch.Generator] = None, return_dict: bool = True, - ) -> Union[AmusedSchedulerOutput, Tuple]: + ) -> Union[AmusedSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]: + """ + Predict the sample at the previous timestep by masking tokens based on confidence scores. 
+ + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. Typically of shape `(batch_size, num_tokens, + codebook_size)` or `(batch_size, codebook_size, height, width)` for 2D inputs. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.LongTensor`): + A current instance of a sample created by the diffusion process. Contains token IDs, with masked + positions indicated by `mask_token_id`. + starting_mask_ratio (`float`, *optional*, defaults to 1.0): + A multiplier applied to the mask ratio schedule. Values less than 1.0 will result in fewer tokens being + masked at each step. + generator (`torch.Generator`, *optional*): + A random number generator for reproducible sampling. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return an [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or a plain tuple. + + Returns: + [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or `tuple`: + If `return_dict` is `True`, [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor (`prev_sample`) and the + second element is the predicted original sample tensor (`pred_original_sample`). + """ two_dim_input = sample.ndim == 3 and model_output.ndim == 4 if two_dim_input: @@ -137,7 +230,27 @@ def step( return AmusedSchedulerOutput(prev_sample, pred_original_sample) - def add_noise(self, sample, timesteps, generator=None): + def add_noise( + self, + sample: torch.LongTensor, + timesteps: int, + generator: Optional[torch.Generator] = None, + ) -> torch.LongTensor: + """ + Add noise to a sample by randomly masking tokens according to the masking schedule. + + Args: + sample (`torch.LongTensor`): + The input sample containing token IDs to be partially masked. + timesteps (`int`): + The timestep that determines how much masking to apply. Higher timesteps result in more masking. + generator (`torch.Generator`, *optional*): + A random number generator for reproducible masking. + + Returns: + `torch.LongTensor`: + The sample with some tokens replaced by `mask_token_id` according to the masking schedule. 
+ """ step_idx = (self.timesteps == timesteps).nonzero() ratio = (step_idx + 1) / len(self.timesteps) From d8e4805816df32ccecc070ccd6895e35cdafa723 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Wed, 12 Nov 2025 18:52:31 -0800 Subject: [PATCH 05/35] [WIP]Add Wan2.2 Animate Pipeline (Continuation of #12442 by tolgacangoz) (#12526) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --------- Co-authored-by: Tolga Cangöz Co-authored-by: Tolga Cangöz <46008593+tolgacangoz@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + .../api/models/wan_animate_transformer_3d.md | 30 + docs/source/en/api/pipelines/wan.md | 255 +++- scripts/convert_wan_to_diffusers.py | 271 +++- src/diffusers/__init__.py | 4 + src/diffusers/image_processor.py | 4 +- src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_sana_video.py | 11 +- .../transformers/transformer_wan_animate.py | 1298 +++++++++++++++++ src/diffusers/pipelines/__init__.py | 16 +- src/diffusers/pipelines/wan/__init__.py | 3 +- .../pipelines/wan/image_processor.py | 185 +++ .../pipelines/wan/pipeline_wan_animate.py | 1204 +++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 15 + .../dummy_torch_and_transformers_objects.py | 15 + .../test_models_transformer_wan_animate.py | 126 ++ tests/pipelines/wan/test_wan_animate.py | 239 +++ tests/quantization/gguf/test_gguf.py | 28 + 19 files changed, 3676 insertions(+), 33 deletions(-) create mode 100644 docs/source/en/api/models/wan_animate_transformer_3d.md create mode 100644 src/diffusers/models/transformers/transformer_wan_animate.py create mode 100644 src/diffusers/pipelines/wan/image_processor.py create mode 100644 src/diffusers/pipelines/wan/pipeline_wan_animate.py create mode 100644 tests/models/transformers/test_models_transformer_wan_animate.py create mode 100644 tests/pipelines/wan/test_wan_animate.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 55fe2a9a379f..77eacba664a2 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -387,6 +387,8 @@ title: Transformer2DModel - local: api/models/transformer_temporal title: TransformerTemporalModel + - local: api/models/wan_animate_transformer_3d + title: WanAnimateTransformer3DModel - local: api/models/wan_transformer_3d title: WanTransformer3DModel title: Transformers diff --git a/docs/source/en/api/models/wan_animate_transformer_3d.md b/docs/source/en/api/models/wan_animate_transformer_3d.md new file mode 100644 index 000000000000..798afc72fb8e --- /dev/null +++ b/docs/source/en/api/models/wan_animate_transformer_3d.md @@ -0,0 +1,30 @@ + + +# WanAnimateTransformer3DModel + +A Diffusion Transformer model for 3D video-like data was introduced in [Wan Animate](https://github.com/Wan-Video/Wan2.2) by the Alibaba Wan Team. + +The model can be loaded with the following code snippet. 
+ +```python +from diffusers import WanAnimateTransformer3DModel + +transformer = WanAnimateTransformer3DModel.from_pretrained("Wan-AI/Wan2.2-Animate-14B-720P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## WanAnimateTransformer3DModel + +[[autodoc]] WanAnimateTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 3289a840e2b1..3993e2efd0c8 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -40,6 +40,7 @@ The following Wan models are supported in Diffusers: - [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) - [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) - [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers) +- [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers) > [!TIP] > Click on the Wan models in the right sidebar for more examples of video generation. @@ -95,15 +96,15 @@ pipeline = WanPipeline.from_pretrained( pipeline.to("cuda") prompt = """ -The camera rushes from far to near in a low-angle shot, -revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in -for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. -Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic +The camera rushes from far to near in a low-angle shot, +revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in +for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. +Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic shadows and warm highlights. Medium composition, front view, low angle, with depth of field. """ negative_prompt = """ -Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, -low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards """ @@ -150,15 +151,15 @@ pipeline.transformer = torch.compile( ) prompt = """ -The camera rushes from far to near in a low-angle shot, -revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in -for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. -Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic +The camera rushes from far to near in a low-angle shot, +revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in +for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. +Birch trees and a light blue sky frame the scene, with ferns in the foreground. 
Side lighting casts dynamic shadows and warm highlights. Medium composition, front view, low angle, with depth of field. """ negative_prompt = """ -Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, -low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards """ @@ -249,6 +250,220 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color. + + + +### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication + +[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team. + +*We introduce Wan-Animate, a unified framework for character animation and replacement. Given a character image and a reference video, Wan-Animate can animate the character by precisely replicating the expressions and movements of the character in the video to generate high-fidelity character videos. Alternatively, it can integrate the animated character into the reference video to replace the original character, replicating the scene's lighting and color tone to achieve seamless environmental integration. Wan-Animate is built upon the Wan model. To adapt it for character animation tasks, we employ a modified input paradigm to differentiate between reference conditions and regions for generation. This design unifies multiple tasks into a common symbolic representation. We use spatially-aligned skeleton signals to replicate body motion and implicit facial features extracted from source images to reenact expressions, enabling the generation of character videos with high controllability and expressiveness. Furthermore, to enhance environmental integration during character replacement, we develop an auxiliary Relighting LoRA. This module preserves the character's appearance consistency while applying the appropriate environmental lighting and color tone. Experimental results demonstrate that Wan-Animate achieves state-of-the-art performance. We are committed to open-sourcing the model weights and its source code.* + +The project page: https://humanaigc.github.io/wan-animate + +This model was mostly contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz). + +#### Usage + +The Wan-Animate pipeline supports two modes of operation: + +1. **Animation Mode** (default): Animates a character image based on motion and expression from reference videos +2. 
**Replacement Mode**: Replaces a character in a background video with a new character while preserving the scene + +##### Prerequisites + +Before using the pipeline, you need to preprocess your reference video to extract: +- **Pose video**: Contains skeletal keypoints representing body motion +- **Face video**: Contains facial feature representations for expression control + +For replacement mode, you additionally need: +- **Background video**: The original video containing the scene +- **Mask video**: A mask indicating where to generate content (white) vs. preserve original (black) + +> [!NOTE] +> The preprocessing tools are available in the original Wan-Animate repository. Integration of these preprocessing steps into Diffusers is planned for a future release. + +The example below demonstrates how to use the Wan-Animate pipeline: + + + + +```python +import numpy as np +import torch +from diffusers import AutoencoderKLWan, WanAnimatePipeline +from diffusers.utils import export_to_video, load_image, load_video +from transformers import CLIPVisionModel + +model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers" +vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +pipe = WanAnimatePipeline.from_pretrained( + model_id, vae=vae, torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +# Load character image and preprocessed videos +image = load_image("path/to/character.jpg") +pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints +face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features + +# Resize image to match VAE constraints +def aspect_ratio_resize(image, pipe, max_area=720 * 1280): + aspect_ratio = image.height / image.width + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + return image, height, width + +image, height, width = aspect_ratio_resize(image, pipe) + +prompt = "A person dancing energetically in a studio with dynamic lighting and professional camera work" +negative_prompt = "blurry, low quality, distorted, deformed, static, poorly drawn" + +# Generate animated video +output = pipe( + image=image, + pose_video=pose_video, + face_video=face_video, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=81, + guidance_scale=5.0, + mode="animation", # Animation mode (default) +).frames[0] +export_to_video(output, "animated_character.mp4", fps=16) +``` + + + + +```python +import numpy as np +import torch +from diffusers import AutoencoderKLWan, WanAnimatePipeline +from diffusers.utils import export_to_video, load_image, load_video +from transformers import CLIPVisionModel + +model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers" +image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float16) +vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +pipe = WanAnimatePipeline.from_pretrained( + model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +# Load all required inputs for replacement mode +image = load_image("path/to/new_character.jpg") +pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints +face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features 
+background_video = load_video("path/to/background_video.mp4") # Original scene +mask_video = load_video("path/to/mask_video.mp4") # Black: preserve, White: generate + +# Resize image to match video dimensions +def aspect_ratio_resize(image, pipe, max_area=720 * 1280): + aspect_ratio = image.height / image.width + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + return image, height, width + +image, height, width = aspect_ratio_resize(image, pipe) + +prompt = "A person seamlessly integrated into the scene with consistent lighting and environment" +negative_prompt = "blurry, low quality, inconsistent lighting, floating, disconnected from scene" + +# Replace character in background video +output = pipe( + image=image, + pose_video=pose_video, + face_video=face_video, + background_video=background_video, + mask_video=mask_video, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=81, + guidance_scale=5.0, + mode="replacement", # Replacement mode +).frames[0] +export_to_video(output, "character_replaced.mp4", fps=16) +``` + + + + +```python +import numpy as np +import torch +from diffusers import AutoencoderKLWan, WanAnimatePipeline +from diffusers.utils import export_to_video, load_image, load_video +from transformers import CLIPVisionModel + +model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers" +image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float16) +vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +pipe = WanAnimatePipeline.from_pretrained( + model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +image = load_image("path/to/character.jpg") +pose_video = load_video("path/to/pose_video.mp4") +face_video = load_video("path/to/face_video.mp4") + +def aspect_ratio_resize(image, pipe, max_area=720 * 1280): + aspect_ratio = image.height / image.width + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + return image, height, width + +image, height, width = aspect_ratio_resize(image, pipe) + +prompt = "A person dancing energetically in a studio" +negative_prompt = "blurry, low quality" + +# Advanced: Use temporal guidance and custom callback +def callback_fn(pipe, step_index, timestep, callback_kwargs): + # You can modify latents or other tensors here + print(f"Step {step_index}, Timestep {timestep}") + return callback_kwargs + +output = pipe( + image=image, + pose_video=pose_video, + face_video=face_video, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=81, + num_inference_steps=50, + guidance_scale=5.0, + num_frames_for_temporal_guidance=5, # Use 5 frames for temporal guidance (1 or 5 recommended) + callback_on_step_end=callback_fn, + callback_on_step_end_tensor_inputs=["latents"], +).frames[0] +export_to_video(output, "animated_advanced.mp4", fps=16) +``` + + + + +#### Key Parameters + +- **mode**: Choose between `"animation"` (default) or `"replacement"` +- **num_frames_for_temporal_guidance**: Number of frames for temporal 
guidance (1 or 5 recommended). Using 5 provides better temporal consistency but requires more memory +- **guidance_scale**: Controls how closely the output follows the text prompt. Higher values (5-7) produce results more aligned with the prompt +- **num_frames**: Total number of frames to generate. Should be divisible by `vae_scale_factor_temporal` (default: 4) + + ## Notes - Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`]. @@ -281,10 +496,10 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip # use "steamboat willie style" to trigger the LoRA prompt = """ - steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot, - revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in - for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. - Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic + steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot, + revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in + for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. + Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic shadows and warm highlights. Medium composition, front view, low angle, with depth of field. """ @@ -359,6 +574,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip - all - __call__ +## WanAnimatePipeline + +[[autodoc]] WanAnimatePipeline + - all + - __call__ + ## WanPipelineOutput -[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput \ No newline at end of file +[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 39a364b07d78..06f87409262a 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -6,11 +6,20 @@ from accelerate import init_empty_weights from huggingface_hub import hf_hub_download, snapshot_download from safetensors.torch import load_file -from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel +from transformers import ( + AutoProcessor, + AutoTokenizer, + CLIPImageProcessor, + CLIPVisionModel, + CLIPVisionModelWithProjection, + UMT5EncoderModel, +) from diffusers import ( AutoencoderKLWan, UniPCMultistepScheduler, + WanAnimatePipeline, + WanAnimateTransformer3DModel, WanImageToVideoPipeline, WanPipeline, WanTransformer3DModel, @@ -105,8 +114,203 @@ "after_proj": "proj_out", } +ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 + "norm2": "norm__placeholder", + "norm3": "norm2", + 
"norm__placeholder": "norm3", + "img_emb.proj.0": "condition_embedder.image_embedder.norm1", + "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", + "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", + "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + # Add attention component mappings + "self_attn.q": "attn1.to_q", + "self_attn.k": "attn1.to_k", + "self_attn.v": "attn1.to_v", + "self_attn.o": "attn1.to_out.0", + "self_attn.norm_q": "attn1.norm_q", + "self_attn.norm_k": "attn1.norm_k", + "cross_attn.q": "attn2.to_q", + "cross_attn.k": "attn2.to_k", + "cross_attn.v": "attn2.to_v", + "cross_attn.o": "attn2.to_out.0", + "cross_attn.norm_q": "attn2.norm_q", + "cross_attn.norm_k": "attn2.norm_k", + "cross_attn.k_img": "attn2.to_k_img", + "cross_attn.v_img": "attn2.to_v_img", + "cross_attn.norm_k_img": "attn2.norm_k_img", + # After cross_attn -> attn2 rename, we need to rename the img keys + "attn2.to_k_img": "attn2.add_k_proj", + "attn2.to_v_img": "attn2.add_v_proj", + "attn2.norm_k_img": "attn2.norm_added_k", + # Wan Animate-specific mappings (motion encoder, face encoder, face adapter) + # Motion encoder mappings + # The name mapping is complicated for the convolutional part so we handle that in its own function + "motion_encoder.enc.fc": "motion_encoder.motion_network", + "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight", + # Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten + "face_encoder.conv1_local.conv": "face_encoder.conv1_local", + "face_encoder.conv2.conv": "face_encoder.conv2", + "face_encoder.conv3.conv": "face_encoder.conv3", + # Face adapter mappings are handled in a separate function +} + + +# TODO: Verify this and simplify if possible. +def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None: + """ + Convert all motion encoder weights for Animate model. + + In the original model: + - All Linear layers in fc use EqualLinear + - All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately) + - Blur kernels are stored as buffers in Sequential modules + - ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)] + + Conversion strategy: + 1. Drop .kernel buffers (blur kernels) + 2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu) + """ + # Skip if not a weight, bias, or kernel + if ".weight" not in key and ".bias" not in key and ".kernel" not in key: + return + + # Handle Blur kernel buffers from original implementation. + # After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel + # Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys + if ".kernel" in key and "motion_encoder" in key: + # Remove unexpected blur kernel buffers to avoid strict load errors + state_dict.pop(key, None) + return + + # Rename Sequential indices to named components in ConvLayer and ResBlock + if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key): + parts = key.split(".") + + # Find the sequential index (digit) after convs or after conv1/conv2/skip + # Examples: + # - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight) + # - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias) + # - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight) + # - e.g. 
enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight + # - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias) + # - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias + # - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight) + # - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias) + # - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight) + # - enc.net_app.convs.8 -> conv_out (final conv layer) + + convs_idx = parts.index("convs") if "convs" in parts else -1 + if convs_idx >= 0 and len(parts) - convs_idx >= 2: + bias = False + # The nn.Sequential index will always follow convs + sequential_idx = int(parts[convs_idx + 1]) + if sequential_idx == 0: + if key.endswith(".weight"): + new_key = "motion_encoder.conv_in.weight" + elif key.endswith(".bias"): + new_key = "motion_encoder.conv_in.act_fn.bias" + bias = True + elif sequential_idx == final_conv_idx: + if key.endswith(".weight"): + new_key = "motion_encoder.conv_out.weight" + else: + # Intermediate .convs. layers, which get mapped to .res_blocks. + prefix = "motion_encoder.res_blocks." + + layer_name = parts[convs_idx + 2] + if layer_name == "skip": + layer_name = "conv_skip" + + if key.endswith(".weight"): + param_name = "weight" + elif key.endswith(".bias"): + param_name = "act_fn.bias" + bias = True + + suffix_parts = [str(sequential_idx - 1), layer_name, param_name] + suffix = ".".join(suffix_parts) + new_key = prefix + suffix + + param = state_dict.pop(key) + if bias: + param = param.squeeze() + state_dict[new_key] = param + return + return + return + + +def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None: + """ + Convert face adapter weights for the Animate model. + + The original model uses a fused KV projection but the diffusers models uses separate K and V projections. + """ + # Skip if not a weight or bias + if ".weight" not in key and ".bias" not in key: + return + + prefix = "face_adapter." + if ".fuser_blocks." 
in key: + parts = key.split(".") + + module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1 + if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3: + block_idx = parts[module_list_idx + 1] + layer_name = parts[module_list_idx + 2] + param_name = parts[module_list_idx + 3] + + if layer_name == "linear1_kv": + layer_name_k = "to_k" + layer_name_v = "to_v" + + suffix_k = ".".join([block_idx, layer_name_k, param_name]) + suffix_v = ".".join([block_idx, layer_name_v, param_name]) + new_key_k = prefix + suffix_k + new_key_v = prefix + suffix_v + + kv_proj = state_dict.pop(key) + k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0) + state_dict[new_key_k] = k_proj + state_dict[new_key_v] = v_proj + return + else: + if layer_name == "q_norm": + new_layer_name = "norm_q" + elif layer_name == "k_norm": + new_layer_name = "norm_k" + elif layer_name == "linear1_q": + new_layer_name = "to_q" + elif layer_name == "linear2": + new_layer_name = "to_out" + + suffix_parts = [block_idx, new_layer_name, param_name] + suffix = ".".join(suffix_parts) + new_key = prefix + suffix + state_dict[new_key] = state_dict.pop(key) + return + return + + TRANSFORMER_SPECIAL_KEYS_REMAP = {} VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {} +ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = { + "motion_encoder": convert_animate_motion_encoder_weights, + "face_adapter": convert_animate_face_adapter_weights, +} def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: @@ -364,6 +568,37 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: } RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP + elif model_type == "Wan2.2-Animate-14B": + config = { + "model_id": "Wan-AI/Wan2.2-Animate-14B", + "diffusers_config": { + "image_dim": 1280, + "added_kv_proj_dim": 5120, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 36, + "num_attention_heads": 40, + "num_layers": 40, + "out_channels": 16, + "patch_size": (1, 2, 2), + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + "rope_max_seq_len": 1024, + "pos_embed_seq_len": None, + "motion_encoder_size": 512, # Start of Wan Animate-specific configs + "motion_style_dim": 512, + "motion_dim": 20, + "motion_encoder_dim": 512, + "face_encoder_hidden_dim": 1024, + "face_encoder_num_heads": 4, + "inject_face_latents_blocks": 5, + }, + } + RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT + SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP return config, RENAME_DICT, SPECIAL_KEYS_REMAP @@ -380,10 +615,12 @@ def convert_transformer(model_type: str, stage: str = None): original_state_dict = load_sharded_safetensors(model_dir) with init_empty_weights(): - if "VACE" not in model_type: - transformer = WanTransformer3DModel.from_config(diffusers_config) - else: + if "Animate" in model_type: + transformer = WanAnimateTransformer3DModel.from_config(diffusers_config) + elif "VACE" in model_type: transformer = WanVACETransformer3DModel.from_config(diffusers_config) + else: + transformer = WanTransformer3DModel.from_config(diffusers_config) for key in list(original_state_dict.keys()): new_key = key[:] @@ -397,7 +634,12 @@ def convert_transformer(model_type: str, stage: str = None): continue handler_fn_inplace(key, original_state_dict) + # Load state dict into the meta model, which will materialize the tensors transformer.load_state_dict(original_state_dict, strict=True, assign=True) + + # 
Move to CPU to ensure all tensors are materialized + transformer = transformer.to("cpu") + return transformer @@ -926,7 +1168,7 @@ def get_args(): if __name__ == "__main__": args = get_args() - if "Wan2.2" in args.model_type and "TI2V" not in args.model_type: + if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type: transformer = convert_transformer(args.model_type, stage="high_noise_model") transformer_2 = convert_transformer(args.model_type, stage="low_noise_model") else: @@ -942,7 +1184,7 @@ def get_args(): tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") if "FLF2V" in args.model_type: flow_shift = 16.0 - elif "TI2V" in args.model_type: + elif "TI2V" in args.model_type or "Animate" in args.model_type: flow_shift = 5.0 else: flow_shift = 3.0 @@ -954,6 +1196,8 @@ def get_args(): if args.dtype != "none": dtype = DTYPE_MAPPING[args.dtype] transformer.to(dtype) + if transformer_2 is not None: + transformer_2.to(dtype) if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type: pipe = WanImageToVideoPipeline( @@ -1016,6 +1260,21 @@ def get_args(): vae=vae, scheduler=scheduler, ) + elif "Animate" in args.model_type: + image_encoder = CLIPVisionModel.from_pretrained( + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16 + ) + image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + + pipe = WanAnimatePipeline( + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) else: pipe = WanPipeline( transformer=transformer, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a5040bd28394..02df34c07e8e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -268,6 +268,7 @@ "UNetSpatioTemporalConditionModel", "UVit2DModel", "VQModel", + "WanAnimateTransformer3DModel", "WanTransformer3DModel", "WanVACETransformer3DModel", "attention_backend", @@ -636,6 +637,7 @@ "VisualClozeGenerationPipeline", "VisualClozePipeline", "VQDiffusionPipeline", + "WanAnimatePipeline", "WanImageToVideoPipeline", "WanPipeline", "WanVACEPipeline", @@ -977,6 +979,7 @@ UNetSpatioTemporalConditionModel, UVit2DModel, VQModel, + WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, attention_backend, @@ -1315,6 +1318,7 @@ VisualClozeGenerationPipeline, VisualClozePipeline, VQDiffusionPipeline, + WanAnimatePipeline, WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 067d876ffcd8..abd0a25819f5 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -409,7 +409,7 @@ def _resize_and_fill( src_w = width if ratio < src_ratio else image.width * height // image.height src_h = height if ratio >= src_ratio else image.height * width // image.width - resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"]) + resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample]) res = Image.new("RGB", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) @@ -460,7 +460,7 @@ def _resize_and_crop( src_w = width if ratio > src_ratio else image.width * height // image.height src_h = height if ratio <= src_ratio else image.height * width // image.width - resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"]) + 
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample]) res = Image.new("RGB", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) return res diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e97ab8bd1d2a..b42e981f71a9 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -108,6 +108,7 @@ _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] + _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] @@ -214,6 +215,7 @@ T5FilmDecoder, Transformer2DModel, TransformerTemporalModel, + WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, ) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 66daf56e23b2..826469237fb1 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -42,4 +42,5 @@ from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel from .transformer_temporal import TransformerTemporalModel from .transformer_wan import WanTransformer3DModel + from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index 424d9ff9d360..f9fc971950b0 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -188,6 +188,11 @@ def __init__( h_dim = w_dim = 2 * (attention_head_dim // 6) t_dim = attention_head_dim - h_dim - w_dim + + self.t_dim = t_dim + self.h_dim = h_dim + self.w_dim = w_dim + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 freqs_cos = [] @@ -213,11 +218,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: p_t, p_h, p_w = self.patch_size ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - split_sizes = [ - self.attention_head_dim - 2 * (self.attention_head_dim // 3), - self.attention_head_dim // 3, - self.attention_head_dim // 3, - ] + split_sizes = [self.t_dim, self.h_dim, self.w_dim] freqs_cos = self.freqs_cos.split(split_sizes, dim=1) freqs_sin = self.freqs_sin.split(split_sizes, dim=1) diff --git a/src/diffusers/models/transformers/transformer_wan_animate.py b/src/diffusers/models/transformers/transformer_wan_animate.py new file mode 100644 index 000000000000..6a47a67385a3 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_wan_animate.py @@ -0,0 +1,1298 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES = { + "4": 512, + "8": 512, + "16": 512, + "32": 512, + "64": 256, + "128": 128, + "256": 64, + "512": 32, + "1024": 16, +} + + +# Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections +def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): + # encoder_hidden_states is only passed for cross-attention + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.fused_projections: + if attn.cross_attention_dim_head is None: + # In self-attention layers, we can fuse the entire QKV projection into a single linear + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + # In cross-attention layers, we can only fuse the KV projections into a single linear + query = attn.to_q(hidden_states) + key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + return query, key, value + + +# Copied from diffusers.models.transformers.transformer_wan._get_added_kv_projections +def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor): + if attn.fused_projections: + key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1) + else: + key_img = attn.add_k_proj(encoder_hidden_states_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + return key_img, value_img + + +class FusedLeakyReLU(nn.Module): + """ + Fused LeakyRelu with scale factor and channel-wise bias. 
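+
+    The optional per-channel bias is broadcast across all non-channel dimensions and added to the input
+    before the leaky ReLU; the activated output is then multiplied by `scale` (sqrt(2) by default).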
+ """ + + def __init__(self, negative_slope: float = 0.2, scale: float = 2**0.5, bias_channels: Optional[int] = None): + super().__init__() + self.negative_slope = negative_slope + self.scale = scale + self.channels = bias_channels + + if self.channels is not None: + self.bias = nn.Parameter( + torch.zeros( + self.channels, + ) + ) + else: + self.bias = None + + def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + if self.bias is not None: + # Expand self.bias to have all singleton dims except at self.channel_dim + expanded_shape = [1] * x.ndim + expanded_shape[channel_dim] = self.bias.shape[0] + bias = self.bias.reshape(*expanded_shape) + x = x + bias + return F.leaky_relu(x, self.negative_slope) * self.scale + + +class MotionConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + blur_kernel: Optional[Tuple[int, ...]] = None, + blur_upsample_factor: int = 1, + use_activation: bool = True, + ): + super().__init__() + self.use_activation = use_activation + self.in_channels = in_channels + + # Handle blurring (applying a FIR filter with the given kernel) if available + self.blur = False + if blur_kernel is not None: + p = (len(blur_kernel) - stride) + (kernel_size - 1) + self.blur_padding = ((p + 1) // 2, p // 2) + + kernel = torch.tensor(blur_kernel) + # Convert kernel to 2D if necessary + if kernel.ndim == 1: + kernel = kernel[None, :] * kernel[:, None] + # Normalize kernel + kernel = kernel / kernel.sum() + if blur_upsample_factor > 1: + kernel = kernel * (blur_upsample_factor**2) + self.register_buffer("blur_kernel", kernel, persistent=False) + self.blur = True + + # Main Conv2d parameters (with scale factor) + self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + + self.stride = stride + self.padding = padding + + # If using an activation function, the bias will be fused into the activation + if bias and not self.use_activation: + self.bias = nn.Parameter(torch.zeros(out_channels)) + else: + self.bias = None + + if self.use_activation: + self.act_fn = FusedLeakyReLU(bias_channels=out_channels) + else: + self.act_fn = None + + def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + # Apply blur if using + if self.blur: + # NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates + # set to 1, which should be equivalent to a 2D convolution + expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1) + x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels) + + # Main Conv2D with scaling + x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + # Activation with fused bias, if using + if self.use_activation: + x = self.act_fn(x, channel_dim=channel_dim) + return x + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," + f" kernel_size={self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" + ) + + +class MotionLinear(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = True, + use_activation: bool = False, + ): + super().__init__() + self.use_activation = use_activation + + # Linear weight with scale factor + self.weight = nn.Parameter(torch.randn(out_dim, in_dim)) + self.scale = 1 / 
math.sqrt(in_dim) + + # If an activation is present, the bias will be fused to it + if bias and not self.use_activation: + self.bias = nn.Parameter(torch.zeros(out_dim)) + else: + self.bias = None + + if self.use_activation: + self.act_fn = FusedLeakyReLU(bias_channels=out_dim) + else: + self.act_fn = None + + def forward(self, input: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + out = F.linear(input, self.weight * self.scale, bias=self.bias) + if self.use_activation: + out = self.act_fn(out, channel_dim=channel_dim) + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}(in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}," + f" bias={self.bias is not None})" + ) + + +class MotionEncoderResBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + kernel_size_skip: int = 1, + blur_kernel: Tuple[int, ...] = (1, 3, 3, 1), + downsample_factor: int = 2, + ): + super().__init__() + self.downsample_factor = downsample_factor + + # 3 x 3 Conv + fused leaky ReLU + self.conv1 = MotionConv2d( + in_channels, + in_channels, + kernel_size, + stride=1, + padding=kernel_size // 2, + use_activation=True, + ) + + # 3 x 3 Conv that downsamples 2x + fused leaky ReLU + self.conv2 = MotionConv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=self.downsample_factor, + padding=0, + blur_kernel=blur_kernel, + use_activation=True, + ) + + # 1 x 1 Conv that downsamples 2x in skip connection + self.conv_skip = MotionConv2d( + in_channels, + out_channels, + kernel_size=kernel_size_skip, + stride=self.downsample_factor, + padding=0, + bias=False, + blur_kernel=blur_kernel, + use_activation=False, + ) + + def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + x_out = self.conv1(x, channel_dim) + x_out = self.conv2(x_out, channel_dim) + + x_skip = self.conv_skip(x, channel_dim) + + x_out = (x_out + x_skip) / math.sqrt(2) + return x_out + + +class WanAnimateMotionEncoder(nn.Module): + def __init__( + self, + size: int = 512, + style_dim: int = 512, + motion_dim: int = 20, + out_dim: int = 512, + motion_blocks: int = 5, + channels: Optional[Dict[str, int]] = None, + ): + super().__init__() + self.size = size + + # Appearance encoder: conv layers + if channels is None: + channels = WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES + + self.conv_in = MotionConv2d(3, channels[str(size)], 1, use_activation=True) + + self.res_blocks = nn.ModuleList() + in_channels = channels[str(size)] + log_size = int(math.log(size, 2)) + for i in range(log_size, 2, -1): + out_channels = channels[str(2 ** (i - 1))] + self.res_blocks.append(MotionEncoderResBlock(in_channels, out_channels)) + in_channels = out_channels + + self.conv_out = MotionConv2d(in_channels, style_dim, 4, padding=0, bias=False, use_activation=False) + + # Motion encoder: linear layers + # NOTE: there are no activations in between the linear layers here, which is weird but I believe matches the + # original code. 
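+        # The motion network maps the style vector (style_dim) to a compact motion code (motion_dim);
+        # forward() then expands it back to `out_dim` via the learned `motion_synthesis_weight` basis.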
+ linears = [MotionLinear(style_dim, style_dim) for _ in range(motion_blocks - 1)] + linears.append(MotionLinear(style_dim, motion_dim)) + self.motion_network = nn.ModuleList(linears) + + self.motion_synthesis_weight = nn.Parameter(torch.randn(out_dim, motion_dim)) + + def forward(self, face_image: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + if (face_image.shape[-2] != self.size) or (face_image.shape[-1] != self.size): + raise ValueError( + f"Face pixel values has resolution ({face_image.shape[-1]}, {face_image.shape[-2]}) but is expected" + f" to have resolution ({self.size}, {self.size})" + ) + + # Appearance encoding through convs + face_image = self.conv_in(face_image, channel_dim) + for block in self.res_blocks: + face_image = block(face_image, channel_dim) + face_image = self.conv_out(face_image, channel_dim) + motion_feat = face_image.squeeze(-1).squeeze(-1) + + # Motion feature extraction + for linear_layer in self.motion_network: + motion_feat = linear_layer(motion_feat, channel_dim=channel_dim) + + # Motion synthesis via Linear Motion Decomposition + weight = self.motion_synthesis_weight + 1e-8 + # Upcast the QR orthogonalization operation to FP32 + original_motion_dtype = motion_feat.dtype + motion_feat = motion_feat.to(torch.float32) + weight = weight.to(torch.float32) + + Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device) + + motion_feat_diag = torch.diag_embed(motion_feat) # Alpha, diagonal matrix + motion_decomposition = torch.matmul(motion_feat_diag, Q.T) + motion_vec = torch.sum(motion_decomposition, dim=1) + + motion_vec = motion_vec.to(dtype=original_motion_dtype) + + return motion_vec + + +class WanAnimateFaceEncoder(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + hidden_dim: int = 1024, + num_heads: int = 4, + kernel_size: int = 3, + eps: float = 1e-6, + pad_mode: str = "replicate", + ): + super().__init__() + self.num_heads = num_heads + self.time_causal_padding = (kernel_size - 1, 0) + self.pad_mode = pad_mode + + self.act = nn.SiLU() + + self.conv1_local = nn.Conv1d(in_dim, hidden_dim * num_heads, kernel_size=kernel_size, stride=1) + self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride=2) + self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride=2) + + self.norm1 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False) + self.norm2 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False) + self.norm3 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False) + + self.out_proj = nn.Linear(hidden_dim, out_dim) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, out_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size = x.shape[0] + + # Reshape to channels-first to apply causal Conv1d over frame dim + x = x.permute(0, 2, 1) + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + x = self.conv1_local(x) # [B, C, T_padded] --> [B, N * C, T] + x = x.unflatten(1, (self.num_heads, -1)).flatten(0, 1) # [B, N * C, T] --> [B * N, C, T] + # Reshape back to channels-last to apply LayerNorm over channel dim + x = x.permute(0, 2, 1) + x = self.norm1(x) + x = self.act(x) + + x = x.permute(0, 2, 1) + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + x = self.conv2(x) + x = x.permute(0, 2, 1) + x = self.norm2(x) + x = self.act(x) + + x = x.permute(0, 2, 1) + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + x = self.conv3(x) + x = x.permute(0, 2, 1) + x = self.norm3(x) + x = self.act(x) + + x = self.out_proj(x) + x = x.unflatten(0, (batch_size, 
-1)).permute(0, 2, 1, 3) # [B * N, T, C_out] --> [B, T, N, C_out] + + padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1).to(device=x.device) + x = torch.cat([x, padding], dim=-2) # [B, T, N, C_out] --> [B, T, N + 1, C_out] + + return x + + +class WanAnimateFaceBlockAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or" + f" higher." + ) + + def __call__( + self, + attn: "WanAnimateFaceBlockCrossAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # encoder_hidden_states corresponds to the motion vec + # attention_mask corresponds to the motion mask (if any) + hidden_states = attn.pre_norm_q(hidden_states) + encoder_hidden_states = attn.pre_norm_kv(encoder_hidden_states) + + # B --> batch_size, T --> reduced inference segment len, N --> face_encoder_num_heads + 1, C --> attn.dim + B, T, N, C = encoder_hidden_states.shape + + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) # [B, S, H * D] --> [B, S, H, D] + key = key.view(B, T, N, attn.heads, -1) # [B, T, N, H * D_kv] --> [B, T, N, H, D_kv] + value = value.view(B, T, N, attn.heads, -1) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + # NOTE: the below line (which follows the official code) means that in practice, the number of frames T in + # encoder_hidden_states (the motion vector after applying the face encoder) must evenly divide the + # post-patchify sequence length S of the transformer hidden_states. Is it possible to remove this dependency? + query = query.unflatten(1, (T, -1)).flatten(0, 1) # [B, S, H, D] --> [B * T, S / T, H, D] + key = key.flatten(0, 1) # [B, T, N, H, D_kv] --> [B * T, N, H, D_kv] + value = value.flatten(0, 1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + hidden_states = hidden_states.unflatten(0, (B, T)).flatten(1, 2) + + hidden_states = attn.to_out(hidden_states) + + if attention_mask is not None: + # NOTE: attention_mask is assumed to be a multiplicative mask + attention_mask = attention_mask.flatten(start_dim=1) + hidden_states = hidden_states * attention_mask + + return hidden_states + + +class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin): + """ + Temporally-aligned cross attention with the face motion signal in the Wan Animate Face Blocks. + """ + + _default_processor_cls = WanAnimateFaceBlockAttnProcessor + _available_processors = [WanAnimateFaceBlockAttnProcessor] + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-6, + cross_attention_dim_head: Optional[int] = None, + processor=None, + ): + super().__init__() + self.inner_dim = dim_head * heads + self.heads = heads + self.cross_attention_head_dim = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + # 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector). 
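+        # Both are affine-free LayerNorms applied before the Q and KV projections, respectively.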
+ # NOTE: this is not used in "vanilla" WanAttention + self.pre_norm_q = nn.LayerNorm(dim, eps, elementwise_affine=False) + self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False) + + # 2. QKV and Output Projections + self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True) + + # 3. QK Norm + # NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=True) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=True) + + # 4. Set attention processor + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask) + + +# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor +class WanAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: "WanAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if rotary_emb is not None: + + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) + + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)) + value_img = value_img.unflatten(2, (attn.heads, -1)) + + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states_img = hidden_states_img.flatten(2, 3) + hidden_states_img = 
hidden_states_img.type_as(query) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +# Copied from diffusers.models.transformers.transformer_wan.WanAttention +class WanAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = WanAttnProcessor + _available_processors = [WanAttnProcessor] + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-5, + dropout: float = 0.0, + added_kv_proj_dim: Optional[int] = None, + cross_attention_dim_head: Optional[int] = None, + processor=None, + is_cross_attention=None, + ): + super().__init__() + + self.inner_dim = dim_head * heads + self.heads = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = torch.nn.ModuleList( + [ + torch.nn.Linear(self.inner_dim, dim, bias=True), + torch.nn.Dropout(dropout), + ] + ) + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + + self.add_k_proj = self.add_v_proj = None + if added_kv_proj_dim is not None: + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) + + self.is_cross_attention = cross_attention_dim_head is not None + + self.set_processor(processor) + + def fuse_projections(self): + if getattr(self, "fused_projections", False): + return + + if self.cross_attention_dim_head is None: + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_qkv = nn.Linear(in_features, out_features, bias=True) + self.to_qkv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_kv = nn.Linear(in_features, out_features, bias=True) + self.to_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + if self.added_kv_proj_dim is not None: + concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) + concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) + out_features, in_features = 
concatenated_weights.shape + with torch.device("meta"): + self.to_added_kv = nn.Linear(in_features, out_features, bias=True) + self.to_added_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + self.fused_projections = True + + @torch.no_grad() + def unfuse_projections(self): + if not getattr(self, "fused_projections", False): + return + + if hasattr(self, "to_qkv"): + delattr(self, "to_qkv") + if hasattr(self, "to_kv"): + delattr(self, "to_kv") + if hasattr(self, "to_added_kv"): + delattr(self, "to_added_kv") + + self.fused_projections = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs) + + +# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding +class WanImageEmbedding(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): + super().__init__() + + self.norm1 = FP32LayerNorm(in_features) + self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") + self.norm2 = FP32LayerNorm(out_features) + if pos_embed_seq_len is not None: + self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) + else: + self.pos_embed = None + + def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: + if self.pos_embed is not None: + batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape + encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim) + encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed + + hidden_states = self.norm1(encoder_hidden_states_image) + hidden_states = self.ff(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding +class WanTimeTextImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: Optional[int] = None, + pos_embed_seq_len: Optional[int] = None, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + self.image_embedder = None + if image_embed_dim is not None: + self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len) + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + timestep_seq_len: Optional[int] = None, + ): + timestep = self.timesteps_proj(timestep) + if timestep_seq_len is not None: + timestep = timestep.unflatten(0, (-1, timestep_seq_len)) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = 
self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + + +# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed +class WanRotaryPosEmbed(nn.Module): + def __init__( + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + + self.t_dim = t_dim + self.h_dim = h_dim + self.w_dim = w_dim + + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + + freqs_cos = [] + freqs_sin = [] + + for dim in [t_dim, h_dim, w_dim]: + freq_cos, freq_sin = get_1d_rotary_pos_embed( + dim, + max_seq_len, + theta, + use_real=True, + repeat_interleave_real=True, + freqs_dtype=freqs_dtype, + ) + freqs_cos.append(freq_cos) + freqs_sin.append(freq_sin) + + self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) + self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + split_sizes = [self.t_dim, self.h_dim, self.w_dim] + + freqs_cos = self.freqs_cos.split(split_sizes, dim=1) + freqs_sin = self.freqs_sin.split(split_sizes, dim=1) + + freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + + return freqs_cos, freqs_sin + + +# Copied from diffusers.models.transformers.transformer_wan.WanTransformerBlock +class WanTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: Optional[int] = None, + ): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + cross_attention_dim_head=None, + processor=WanAttnProcessor(), + ) + + # 2. Cross-attention + self.attn2 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + cross_attention_dim_head=dim // num_heads, + processor=WanAttnProcessor(), + ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 3. 
Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + ) -> torch.Tensor: + if temb.ndim == 4: + # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + # batch_size, seq_len, 1, inner_dim + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + hidden_states + ) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + + return hidden_states + + +class WanAnimateTransformer3DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): + r""" + A Transformer model for video-like data used in the WanAnimate model. + + Args: + patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + num_attention_heads (`int`, defaults to `40`): + Fixed length for text embeddings. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_dim (`int`, defaults to `512`): + Input dimension for text embeddings. + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `13824`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `40`): + The number of layers of transformer blocks to use. + window_size (`Tuple[int]`, defaults to `(-1, -1)`): + Window size for local attention (-1 indicates global attention). + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + qk_norm (`bool`, defaults to `True`): + Enable query/key normalization. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + image_dim (`int`, *optional*, defaults to `1280`): + The number of channels to use for the image embedding. If `None`, no projection is used. 
+ added_kv_proj_dim (`int`, *optional*, defaults to `5120`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] + _no_split_modules = ["WanTransformerBlock", "MotionEncoderResBlock"] + _keep_in_fp32_modules = [ + "time_embedder", + "scale_shift_table", + "norm1", + "norm2", + "norm3", + "motion_synthesis_weight", + ] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + _repeated_blocks = ["WanTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: Tuple[int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: Optional[int] = 36, + latent_channels: Optional[int] = 16, + out_channels: Optional[int] = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: Optional[str] = "rms_norm_across_heads", + eps: float = 1e-6, + image_dim: Optional[int] = 1280, + added_kv_proj_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: Optional[int] = None, + motion_encoder_channel_sizes: Optional[Dict[str, int]] = None, # Start of Wan Animate-specific args + motion_encoder_size: int = 512, + motion_style_dim: int = 512, + motion_dim: int = 20, + motion_encoder_dim: int = 512, + face_encoder_hidden_dim: int = 1024, + face_encoder_num_heads: int = 4, + inject_face_latents_blocks: int = 5, + motion_encoder_batch_size: int = 8, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + # Allow either only in_channels or only latent_channels to be set for convenience + if in_channels is None and latent_channels is not None: + in_channels = 2 * latent_channels + 4 + elif in_channels is not None and latent_channels is None: + latent_channels = (in_channels - 4) // 2 + elif in_channels is not None and latent_channels is not None: + # TODO: should this always be true? + assert in_channels == 2 * latent_channels + 4, "in_channels should be 2 * latent_channels + 4" + else: + raise ValueError("At least one of `in_channels` and `latent_channels` must be supplied.") + out_channels = out_channels or latent_channels + + # 1. Patch & position embedding + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + self.pose_patch_embedding = nn.Conv3d(latent_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. Condition embeddings + self.condition_embedder = WanTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len, + ) + + # Motion encoder + self.motion_encoder = WanAnimateMotionEncoder( + size=motion_encoder_size, + style_dim=motion_style_dim, + motion_dim=motion_dim, + out_dim=motion_encoder_dim, + channels=motion_encoder_channel_sizes, + ) + + # Face encoder + self.face_encoder = WanAnimateFaceEncoder( + in_dim=motion_encoder_dim, + out_dim=inner_dim, + hidden_dim=face_encoder_hidden_dim, + num_heads=face_encoder_num_heads, + ) + + # 3. 
Transformer blocks + self.blocks = nn.ModuleList( + [ + WanTransformerBlock( + dim=inner_dim, + ffn_dim=ffn_dim, + num_heads=num_attention_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + ) + for _ in range(num_layers) + ] + ) + + self.face_adapter = nn.ModuleList( + [ + WanAnimateFaceBlockCrossAttention( + dim=inner_dim, + heads=num_attention_heads, + dim_head=inner_dim // num_attention_heads, + eps=eps, + cross_attention_dim_head=inner_dim // num_attention_heads, + processor=WanAnimateFaceBlockAttnProcessor(), + ) + for _ in range(num_layers // inject_face_latents_blocks) + ] + ) + + # 4. Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + pose_hidden_states: Optional[torch.Tensor] = None, + face_pixel_values: Optional[torch.Tensor] = None, + motion_encode_batch_size: Optional[int] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Forward pass of Wan2.2-Animate transformer model. + + Args: + hidden_states (`torch.Tensor` of shape `(B, 2C + 4, T + 1, H, W)`): + Input noisy video latents of shape `(B, 2C + 4, T + 1, H, W)`, where B is the batch size, C is the + number of latent channels (16 for Wan VAE), T is the number of latent frames in an inference segment, H + is the latent height, and W is the latent width. + timestep: (`torch.LongTensor`): + The current timestep in the denoising loop. + encoder_hidden_states (`torch.Tensor`): + Text embeddings from the text encoder (umT5 for Wan Animate). + encoder_hidden_states_image (`torch.Tensor`): + CLIP visual features of the reference (character) image. + pose_hidden_states (`torch.Tensor` of shape `(B, C, T, H, W)`): + Pose video latents. TODO: description + face_pixel_values (`torch.Tensor` of shape `(B, C', S, H', W')`): + Face video in pixel space (not latent space). Typically C' = 3 and H' and W' are the height/width of + the face video in pixels. Here S is the inference segment length, usually set to 77. + motion_encode_batch_size (`int`, *optional*): + The batch size for batched encoding of the face video via the motion encoder. Will default to + `self.config.motion_encoder_batch_size` if not set. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return the output as a dict or tuple. + """ + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + + # Check that shapes match up + if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]: + raise ValueError( + f"pose_hidden_states frame dim (dim 2) is {pose_hidden_states.shape[2]} but must be one less than the" + f" hidden_states's corresponding frame dim: {hidden_states.shape[2]}" + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # 1. Rotary position embedding + rotary_emb = self.rope(hidden_states) + + # 2. Patch embedding + hidden_states = self.patch_embedding(hidden_states) + pose_hidden_states = self.pose_patch_embedding(pose_hidden_states) + # Add pose embeddings to hidden states + hidden_states[:, :, 1:] = hidden_states[:, :, 1:] + pose_hidden_states + # Calling contiguous() here is important so that we don't recompile when performing regional compilation + hidden_states = hidden_states.flatten(2).transpose(1, 2).contiguous() + + # 3. Condition embeddings (time, text, image) + # Wan Animate is based on Wan 2.1 and thus uses Wan 2.1's timestep logic + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=None + ) + + # batch_size, 6, inner_dim + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 4. Get motion features from the face video + # Motion vector computation from face pixel values + batch_size, channels, num_face_frames, height, width = face_pixel_values.shape + # Rearrange from (B, C, T, H, W) to (B*T, C, H, W) + face_pixel_values = face_pixel_values.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width) + + # Extract motion features using motion encoder + # Perform batched motion encoder inference to allow trading off inference speed for memory usage + motion_encode_batch_size = motion_encode_batch_size or self.config.motion_encoder_batch_size + face_batches = torch.split(face_pixel_values, motion_encode_batch_size) + motion_vec_batches = [] + for face_batch in face_batches: + motion_vec_batch = self.motion_encoder(face_batch) + motion_vec_batches.append(motion_vec_batch) + motion_vec = torch.cat(motion_vec_batches) + motion_vec = motion_vec.view(batch_size, num_face_frames, -1) + + # Now get face features from the motion vector + motion_vec = self.face_encoder(motion_vec) + + # Add padding at the beginning (prepend zeros) + pad_face = torch.zeros_like(motion_vec[:, :1]) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + + # 5. Transformer blocks with face adapter integration + for block_idx, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + else: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + # Face adapter integration: apply after every 5th block (0, 5, 10, 15, ...) 
+ if block_idx % self.config.inject_face_latents_blocks == 0: + face_adapter_block_idx = block_idx // self.config.inject_face_latents_blocks + face_adapter_output = self.face_adapter[face_adapter_block_idx](hidden_states, motion_vec) + # In case the face adapter and main transformer blocks are on different devices, which can happen when + # using model parallelism + face_adapter_output = face_adapter_output.to(device=hidden_states.device) + hidden_states = face_adapter_output + hidden_states + + # 6. Output norm, projection & unpatchify + # batch_size, inner_dim + shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + + hidden_states_original_dtype = hidden_states.dtype + hidden_states = self.norm_out(hidden_states.float()) + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + hidden_states = (hidden_states * (1 + scale) + shift).to(dtype=hidden_states_original_dtype) + + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 495753041f10..719ff4c7df15 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -385,7 +385,13 @@ "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", ] - _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"] + _import_structure["wan"] = [ + "WanPipeline", + "WanImageToVideoPipeline", + "WanVideoToVideoPipeline", + "WanVACEPipeline", + "WanAnimatePipeline", + ] _import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", @@ -803,7 +809,13 @@ UniDiffuserTextDecoder, ) from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline - from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline + from .wan import ( + WanAnimatePipeline, + WanImageToVideoPipeline, + WanPipeline, + WanVACEPipeline, + WanVideoToVideoPipeline, + ) from .wuerstchen import ( WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, diff --git a/src/diffusers/pipelines/wan/__init__.py b/src/diffusers/pipelines/wan/__init__.py index bb96372b1db2..ad51a52f9242 100644 --- a/src/diffusers/pipelines/wan/__init__.py +++ b/src/diffusers/pipelines/wan/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_wan"] = ["WanPipeline"] + _import_structure["pipeline_wan_animate"] = ["WanAnimatePipeline"] _import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"] _import_structure["pipeline_wan_vace"] = ["WanVACEPipeline"] _import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"] @@ -35,10 +36,10 @@ from 
...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_wan import WanPipeline + from .pipeline_wan_animate import WanAnimatePipeline from .pipeline_wan_i2v import WanImageToVideoPipeline from .pipeline_wan_vace import WanVACEPipeline from .pipeline_wan_video2video import WanVideoToVideoPipeline - else: import sys diff --git a/src/diffusers/pipelines/wan/image_processor.py b/src/diffusers/pipelines/wan/image_processor.py new file mode 100644 index 000000000000..b1594d08630f --- /dev/null +++ b/src/diffusers/pipelines/wan/image_processor.py @@ -0,0 +1,185 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch + +from ...configuration_utils import register_to_config +from ...image_processor import VaeImageProcessor +from ...utils import PIL_INTERPOLATION + + +class WanAnimateImageProcessor(VaeImageProcessor): + r""" + Image processor to preprocess the reference (character) image for the Wan Animate model. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of + this factor. + vae_latent_channels (`int`, *optional*, defaults to `16`): + VAE latent channels. + spatial_patch_size (`Tuple[int, int]`, *optional*, defaults to `(2, 2)`): + The spatial patch size used by the diffusion transformer. For Wan models, this is typically (2, 2). + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `False`): + Whether to binarize the image to 0/1. + do_convert_rgb (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to RGB format. + do_convert_grayscale (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to grayscale format. + fill_color (`str` or `float` or `Tuple[float, ...]`, *optional*, defaults to `None`): + An optional fill color when `resize_mode` is set to `"fill"`. This will fill the empty space with that + color instead of filling with data from the image. Any valid `color` argument to `PIL.Image.new` is valid; + if `None`, will default to filling with data from `image`. 
+ """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + vae_latent_channels: int = 16, + spatial_patch_size: Tuple[int, int] = (2, 2), + resample: str = "lanczos", + reducing_gap: int = None, + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_rgb: bool = False, + do_convert_grayscale: bool = False, + fill_color: Optional[Union[str, float, Tuple[float, ...]]] = 0, + ): + super().__init__() + if do_convert_rgb and do_convert_grayscale: + raise ValueError( + "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`," + " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.", + " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`", + ) + + def _resize_and_fill( + self, + image: PIL.Image.Image, + width: int, + height: int, + ) -> PIL.Image.Image: + r""" + Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center + the image within the dimensions, filling empty with data from image. + + Args: + image (`PIL.Image.Image`): + The image to resize and fill. + width (`int`): + The width to resize the image to. + height (`int`): + The height to resize the image to. + + Returns: + `PIL.Image.Image`: + The resized and filled image. + """ + + ratio = width / height + src_ratio = image.width / image.height + fill_with_image_data = self.config.fill_color is None + fill_color = self.config.fill_color or 0 + + src_w = width if ratio < src_ratio else image.width * height // image.height + src_h = height if ratio >= src_ratio else image.height * width // image.width + + resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample]) + res = PIL.Image.new("RGB", (width, height), color=fill_color) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + + if fill_with_image_data: + if ratio < src_ratio: + fill_height = height // 2 - src_h // 2 + if fill_height > 0: + res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) + res.paste( + resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), + box=(0, fill_height + src_h), + ) + elif ratio > src_ratio: + fill_width = width // 2 - src_w // 2 + if fill_width > 0: + res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) + res.paste( + resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), + box=(fill_width + src_w, 0), + ) + + return res + + def get_default_height_width( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + height: Optional[int] = None, + width: Optional[int] = None, + ) -> Tuple[int, int]: + r""" + Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`. + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it + should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch + tensor, it should have shape `[batch, channels, height, width]`. + height (`Optional[int]`, *optional*, defaults to `None`): + The height of the preprocessed image. If `None`, the height of the `image` input will be used. + width (`Optional[int]`, *optional*, defaults to `None`): + The width of the preprocessed image. If `None`, the width of the `image` input will be used. 
+ + Returns: + `Tuple[int, int]`: + A tuple containing the height and width, both resized to the nearest integer multiple of + `vae_scale_factor * spatial_patch_size`. + """ + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + else: + height = image.shape[1] + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + else: + width = image.shape[2] + + max_area = width * height + aspect_ratio = height / width + mod_value_h = self.config.vae_scale_factor * self.config.spatial_patch_size[0] + mod_value_w = self.config.vae_scale_factor * self.config.spatial_patch_size[1] + + # Try to preserve the aspect ratio + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value_h * mod_value_h + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value_w * mod_value_w + + return height, width diff --git a/src/diffusers/pipelines/wan/pipeline_wan_animate.py b/src/diffusers/pipelines/wan/pipeline_wan_animate.py new file mode 100644 index 000000000000..c7c983b2f7d4 --- /dev/null +++ b/src/diffusers/pipelines/wan/pipeline_wan_animate.py @@ -0,0 +1,1204 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import html +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL +import regex as re +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, WanAnimateTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import WanAnimateImageProcessor +from .pipeline_output import WanPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> import numpy as np + >>> from diffusers import WanAnimatePipeline + >>> from diffusers.utils import export_to_video, load_image, load_video + + >>> model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers" + >>> pipe = WanAnimatePipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> # Optionally upcast the Wan VAE to FP32 + >>> pipe.vae.to(torch.float32) + >>> pipe.to("cuda") + + >>> # Load the reference character image + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + + >>> # Load pose and face videos (preprocessed from reference video) + >>> # Note: Videos should be preprocessed to extract pose keypoints and face features + >>> # Refer to the Wan-Animate preprocessing documentation for details + >>> pose_video = load_video("path/to/pose_video.mp4") + >>> face_video = load_video("path/to/face_video.mp4") + + >>> # CFG is generally not used for Wan Animate + >>> prompt = ( + ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in " + ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + ... ) + + >>> # Animation mode: Animate the character with the motion from pose/face videos + >>> output = pipe( + ... image=image, + ... pose_video=pose_video, + ... face_video=face_video, + ... prompt=prompt, + ... height=height, + ... width=width, + ... segment_frame_length=77, # Frame length of each inference segment + ... guidance_scale=1.0, + ... num_inference_steps=20, + ... mode="animate", + ... ).frames[0] + >>> export_to_video(output, "output_animation.mp4", fps=30) + + >>> # Replacement mode: Replace a character in the background video + >>> # Requires additional background_video and mask_video inputs + >>> background_video = load_video("path/to/background_video.mp4") + >>> mask_video = load_video("path/to/mask_video.mp4") # Black areas preserved, white areas generated + >>> output = pipe( + ... image=image, + ... pose_video=pose_video, + ... face_video=face_video, + ... background_video=background_video, + ... mask_video=mask_video, + ... prompt=prompt, + ... height=height, + ... width=width, + ... segment_frame_length=77, # Frame length of each inference segment + ... 
guidance_scale=1.0, + ... num_inference_steps=20, + ... mode="replace", + ... ).frames[0] + >>> export_to_video(output, "output_replacement.mp4", fps=30) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class WanAnimatePipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for unified character animation and replacement using Wan-Animate. + + WanAnimatePipeline takes a character image, pose video, and face video as input, and generates a video in two + modes: + + 1. **Animation mode**: The model generates a video of the character image that mimics the human motion in the input + pose and face videos. The character is animated based on the provided motion controls, creating a new animated + video of the character. + + 2. **Replacement mode**: The model replaces a character in a background video with the provided character image, + using the pose and face videos for motion control. This mode requires additional `background_video` and + `mask_video` inputs. The mask video should have black regions where the original content should be preserved and + white regions where the new character should be generated. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.WanLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + image_encoder ([`CLIPVisionModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically + the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. + transformer ([`WanAnimateTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
+ image_processor ([`CLIPImageProcessor`]): + Image processor for preprocessing images before encoding. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + image_processor: CLIPImageProcessor, + image_encoder: CLIPVisionModel, + transformer: WanAnimateTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + transformer=transformer, + scheduler=scheduler, + image_processor=image_processor, + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.video_processor_for_mask = VideoProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_convert_grayscale=True + ) + # In case self.transformer is None (e.g. for some pipeline tests) + spatial_patch_size = self.transformer.config.patch_size[-2:] if self.transformer is not None else (2, 2) + self.vae_image_processor = WanAnimateImageProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, + spatial_patch_size=spatial_patch_size, + resample="bilinear", + fill_color=0, + ) + self.image_processor = image_processor + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image + def encode_image( + self, + image: PipelineImageInput, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + # Copied from 
diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + image, + pose_video, + face_video, + background_video, + mask_video, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + mode=None, + prev_segment_conditioning_frames=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if pose_video is None: + raise ValueError("Provide `pose_video`. Cannot leave `pose_video` undefined.") + if face_video is None: + raise ValueError("Provide `face_video`. Cannot leave `face_video` undefined.") + if not isinstance(pose_video, list) or not isinstance(face_video, list): + raise ValueError("`pose_video` and `face_video` must be lists of PIL images.") + if len(pose_video) == 0 or len(face_video) == 0: + raise ValueError("`pose_video` and `face_video` must contain at least one frame.") + if mode == "replace" and (background_video is None or mask_video is None): + raise ValueError( + "Provide `background_video` and `mask_video`. Cannot leave both `background_video` and `mask_video`" + " undefined when mode is `replace`." + ) + if mode == "replace" and (not isinstance(background_video, list) or not isinstance(mask_video, list)): + raise ValueError("`background_video` and `mask_video` must be lists of PIL images when mode is `replace`.") + + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found" + f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if mode is not None and (not isinstance(mode, str) or mode not in ("animate", "replace")): + raise ValueError( + f"`mode` has to be of type `str` and in ('animate', 'replace') but its type is {type(mode)} and value is {mode}" + ) + + if prev_segment_conditioning_frames is not None and ( + not isinstance(prev_segment_conditioning_frames, int) or prev_segment_conditioning_frames not in (1, 5) + ): + raise ValueError( + f"`prev_segment_conditioning_frames` has to be of type `int` and 1 or 5 but its type is" + f" {type(prev_segment_conditioning_frames)} and value is {prev_segment_conditioning_frames}" + ) + + def get_i2v_mask( + self, + batch_size: int, + latent_t: int, + latent_h: int, + latent_w: int, + mask_len: int = 1, + mask_pixel_values: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, + device: Union[str, torch.device] = "cuda", + ) -> torch.Tensor: + # mask_pixel_values shape (if supplied): [B, C = 1, T, latent_h, latent_w] + if mask_pixel_values is None: + mask_lat_size = torch.zeros( + batch_size, 1, (latent_t - 1) * 4 + 1, latent_h, latent_w, dtype=dtype, device=device + ) + else: + mask_lat_size = mask_pixel_values.clone().to(device=device, dtype=dtype) + mask_lat_size[:, :, :mask_len] = 1 + first_frame_mask = mask_lat_size[:, :, 0:1] + # Repeat first frame mask self.vae_scale_factor_temporal (= 4) times in the frame dimension + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, self.vae_scale_factor_temporal, latent_h, latent_w + ).transpose(1, 2) # [B, C = 1, 4 * T_lat, H_lat, W_lat] --> [B, C = 4, T_lat, H_lat, W_lat] + + return mask_lat_size + + def prepare_reference_image_latents( + self, + image: torch.Tensor, + batch_size: int = 1, + sample_mode: int = "argmax", + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + # image shape: (B, C, H, W) or (B, C, T, H, W) + dtype = dtype or self.vae.dtype + if image.ndim == 4: + # Add a singleton frame dimension after the channels dimension + image = image.unsqueeze(2) + + _, _, _, height, width = image.shape + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + # Encode image to latents using VAE + image = image.to(device=device, dtype=dtype) + if isinstance(generator, list): + # Like in prepare_latents, assume len(generator) == batch_size + ref_image_latents = [ + retrieve_latents(self.vae.encode(image), generator=g, sample_mode=sample_mode) for g in generator + ] + ref_image_latents = torch.cat(ref_image_latents) + else: + ref_image_latents = retrieve_latents(self.vae.encode(image), generator, sample_mode) + # Standardize latents in preparation for Wan VAE encode + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(ref_image_latents.device, 
ref_image_latents.dtype) + ) + latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + ref_image_latents.device, ref_image_latents.dtype + ) + ref_image_latents = (ref_image_latents - latents_mean) * latents_recip_std + # Handle the case where we supply one image and one generator, but batch_size > 1 (e.g. generating multiple + # videos per prompt) + if ref_image_latents.shape[0] == 1 and batch_size > 1: + ref_image_latents = ref_image_latents.expand(batch_size, -1, -1, -1, -1) + + # Prepare I2V mask in latent space and prepend to the reference image latents along channel dim + reference_image_mask = self.get_i2v_mask(batch_size, 1, latent_height, latent_width, 1, None, dtype, device) + reference_image_latents = torch.cat([reference_image_mask, ref_image_latents], dim=1) + + return reference_image_latents + + def prepare_prev_segment_cond_latents( + self, + prev_segment_cond_video: Optional[torch.Tensor] = None, + background_video: Optional[torch.Tensor] = None, + mask_video: Optional[torch.Tensor] = None, + batch_size: int = 1, + segment_frame_length: int = 77, + start_frame: int = 0, + height: int = 720, + width: int = 1280, + prev_segment_cond_frames: int = 1, + task: str = "animate", + interpolation_mode: str = "bicubic", + sample_mode: str = "argmax", + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + # prev_segment_cond_video shape: (B, C, T, H, W) in pixel space if supplied + # background_video shape: (B, C, T, H, W) (same as prev_segment_cond_video shape) + # mask_video shape: (B, 1, T, H, W) (same as prev_segment_cond_video, but with only 1 channel) + dtype = dtype or self.vae.dtype + if prev_segment_cond_video is None: + if task == "replace": + prev_segment_cond_video = background_video[:, :, :prev_segment_cond_frames].to(dtype) + else: + cond_frames_shape = (batch_size, 3, prev_segment_cond_frames, height, width) # In pixel space + prev_segment_cond_video = torch.zeros(cond_frames_shape, dtype=dtype, device=device) + + data_batch_size, channels, _, segment_height, segment_width = prev_segment_cond_video.shape + num_latent_frames = (segment_frame_length - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + if segment_height != height or segment_width != width: + print( + f"Interpolating prev segment cond video from ({segment_width}, {segment_height}) to ({width}, {height})" + ) + # Perform a 4D (spatial) rather than a 5D (spatiotemporal) reshape, following the original code + prev_segment_cond_video = prev_segment_cond_video.transpose(1, 2).flatten(0, 1) # [B * T, C, H, W] + prev_segment_cond_video = F.interpolate( + prev_segment_cond_video, size=(height, width), mode=interpolation_mode + ) + prev_segment_cond_video = prev_segment_cond_video.unflatten(0, (batch_size, -1)).transpose(1, 2) + + # Fill the remaining part of the cond video segment with zeros (if animating) or the background video (if + # replacing). 
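+        # After the concatenation below, `full_segment_cond_video` spans `segment_frame_length` pixel-space frames:
+        # `prev_segment_cond_frames` conditioning frames followed by the zero (animate) or background (replace)
+        # filler frames, and is then VAE-encoded as a single segment.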
+ if task == "replace": + remaining_segment = background_video[:, :, prev_segment_cond_frames:].to(dtype) + else: + remaining_segment_frames = segment_frame_length - prev_segment_cond_frames + remaining_segment = torch.zeros( + batch_size, channels, remaining_segment_frames, height, width, dtype=dtype, device=device + ) + + # Prepend the conditioning frames from the previous segment to the remaining segment video in the frame dim + prev_segment_cond_video = prev_segment_cond_video.to(dtype=dtype) + full_segment_cond_video = torch.cat([prev_segment_cond_video, remaining_segment], dim=2) + + if isinstance(generator, list): + if data_batch_size == len(generator): + prev_segment_cond_latents = [ + retrieve_latents(self.vae.encode(full_segment_cond_video[i].unsqueeze(0)), g, sample_mode) + for i, g in enumerate(generator) + ] + elif data_batch_size == 1: + # Like prepare_latents, assume len(generator) == batch_size + prev_segment_cond_latents = [ + retrieve_latents(self.vae.encode(full_segment_cond_video), g, sample_mode) for g in generator + ] + else: + raise ValueError( + f"The batch size of the prev segment video should be either {len(generator)} or 1 but is" + f" {data_batch_size}" + ) + prev_segment_cond_latents = torch.cat(prev_segment_cond_latents) + else: + prev_segment_cond_latents = retrieve_latents( + self.vae.encode(full_segment_cond_video), generator, sample_mode + ) + # Standardize latents in preparation for Wan VAE encode + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(prev_segment_cond_latents.device, prev_segment_cond_latents.dtype) + ) + latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + prev_segment_cond_latents.device, prev_segment_cond_latents.dtype + ) + prev_segment_cond_latents = (prev_segment_cond_latents - latents_mean) * latents_recip_std + + # Prepare I2V mask + if task == "replace": + mask_video = 1 - mask_video + mask_video = mask_video.permute(0, 2, 1, 3, 4) + mask_video = mask_video.flatten(0, 1) + mask_video = F.interpolate(mask_video, size=(latent_height, latent_width), mode="nearest") + mask_pixel_values = mask_video.unflatten(0, (batch_size, -1)) + mask_pixel_values = mask_pixel_values.permute(0, 2, 1, 3, 4) # output shape: [B, C = 1, T, H_lat, W_lat] + else: + mask_pixel_values = None + prev_segment_cond_mask = self.get_i2v_mask( + batch_size, + num_latent_frames, + latent_height, + latent_width, + mask_len=prev_segment_cond_frames if start_frame > 0 else 0, + mask_pixel_values=mask_pixel_values, + dtype=dtype, + device=device, + ) + + # Prepend cond I2V mask to prev segment cond latents along channel dimension + prev_segment_cond_latents = torch.cat([prev_segment_cond_mask, prev_segment_cond_latents], dim=1) + return prev_segment_cond_latents + + def prepare_pose_latents( + self, + pose_video: torch.Tensor, + batch_size: int = 1, + sample_mode: int = "argmax", + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + # pose_video shape: (B, C, T, H, W) + pose_video = pose_video.to(device=device, dtype=dtype if dtype is not None else self.vae.dtype) + if isinstance(generator, list): + pose_latents = [ + retrieve_latents(self.vae.encode(pose_video), generator=g, sample_mode=sample_mode) for g in generator + ] + pose_latents = torch.cat(pose_latents) + else: + pose_latents = 
retrieve_latents(self.vae.encode(pose_video), generator, sample_mode) + # Standardize latents in preparation for Wan VAE encode + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(pose_latents.device, pose_latents.dtype) + ) + latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + pose_latents.device, pose_latents.dtype + ) + pose_latents = (pose_latents - latents_mean) * latents_recip_std + if pose_latents.shape[0] == 1 and batch_size > 1: + pose_latents = pose_latents.expand(batch_size, -1, -1, -1, -1) + return pose_latents + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 720, + width: int = 1280, + num_frames: int = 77, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames + 1, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents + + def pad_video_frames(self, frames: List[Any], num_target_frames: int) -> List[Any]: + """ + Pads an array-like video `frames` to `num_target_frames` using a "reflect"-like strategy. The frame dimension + is assumed to be the first dimension. 
In the 1D case, we can visualize this strategy as follows: + + pad_video_frames([1, 2, 3, 4, 5], 10) -> [1, 2, 3, 4, 5, 4, 3, 2, 1, 2] + """ + idx = 0 + flip = False + target_frames = [] + while len(target_frames) < num_target_frames: + target_frames.append(deepcopy(frames[idx])) + if flip: + idx -= 1 + else: + idx += 1 + if idx == 0 or idx == len(frames) - 1: + flip = not flip + + return target_frames + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + pose_video: List[PIL.Image.Image], + face_video: List[PIL.Image.Image], + background_video: Optional[List[PIL.Image.Image]] = None, + mask_video: Optional[List[PIL.Image.Image]] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 720, + width: int = 1280, + segment_frame_length: int = 77, + num_inference_steps: int = 20, + mode: str = "animate", + prev_segment_conditioning_frames: int = 1, + motion_encode_batch_size: Optional[int] = None, + guidance_scale: float = 1.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input character image to condition the generation on. Must be an image, a list of images or a + `torch.Tensor`. + pose_video (`List[PIL.Image.Image]`): + The input pose video to condition the generation on. Must be a list of PIL images. + face_video (`List[PIL.Image.Image]`): + The input face video to condition the generation on. Must be a list of PIL images. + background_video (`List[PIL.Image.Image]`, *optional*): + When mode is `"replace"`, the input background video to condition the generation on. Must be a list of + PIL images. + mask_video (`List[PIL.Image.Image]`, *optional*): + When mode is `"replace"`, the input mask video to condition the generation on. Must be a list of PIL + images. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + mode (`str`, defaults to `"animation"`): + The mode of the generation. Choose between `"animate"` and `"replace"`. 
+ prev_segment_conditioning_frames (`int`, defaults to `1`): + The number of frames from the previous video segment to be used for temporal guidance. Recommended to + be 1 or 5. In general, should be 4N + 1, where N is a non-negative integer. + motion_encode_batch_size (`int`, *optional*): + The batch size for batched encoding of the face video via the motion encoder. This allows trading off + inference speed for lower memory usage by setting a smaller batch size. Will default to + `self.transformer.config.motion_encoder_batch_size` if not set. + height (`int`, defaults to `720`): + The height of the generated video. + width (`int`, defaults to `1280`): + The width of the generated video. + segment_frame_length (`int`, defaults to `77`): + The number of frames in each generated video segment. The total frames of video generated will be equal + to the number of frames in `pose_video`; we will generate the video in segments until we have hit this + length. In general, should be 4N + 1, where N is a non-negative integer. + num_inference_steps (`int`, defaults to `20`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `1.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. By default, CFG is not used in Wan + Animate inference. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. 
+ attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + pose_video, + face_video, + background_video, + mask_video, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + mode, + prev_segment_conditioning_frames, + ) + + if segment_frame_length % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`segment_frame_length - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the" + f" nearest number." + ) + segment_frame_length = ( + segment_frame_length // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + ) + segment_frame_length = max(segment_frame_length, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # As we generate in segments of `segment_frame_length`, set the target frame length to be the least multiple + # of the effective segment length greater than or equal to the length of `pose_video`. 
+ cond_video_frames = len(pose_video) + effective_segment_length = segment_frame_length - prev_segment_conditioning_frames + last_segment_frames = (cond_video_frames - prev_segment_conditioning_frames) % effective_segment_length + if last_segment_frames == 0: + num_padding_frames = 0 + else: + num_padding_frames = effective_segment_length - last_segment_frames + num_target_frames = cond_video_frames + num_padding_frames + num_segments = num_target_frames // effective_segment_length + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Preprocess and encode the reference (character) image + image_height, image_width = self.video_processor.get_default_height_width(image) + if image_height != height or image_width != width: + logger.warning(f"Reshaping reference image from ({image_width}, {image_height}) to ({width}, {height})") + image_pixels = self.vae_image_processor.preprocess(image, height=height, width=width, resize_mode="fill").to( + device, dtype=torch.float32 + ) + + # Get CLIP features from the reference image + if image_embeds is None: + image_embeds = self.encode_image(image, device) + image_embeds = image_embeds.repeat(batch_size * num_videos_per_prompt, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 5. Encode conditioning videos (pose, face) + pose_video = self.pad_video_frames(pose_video, num_target_frames) + face_video = self.pad_video_frames(face_video, num_target_frames) + + # TODO: also support np.ndarray input (e.g. from decord like the original implementation?) + pose_video_width, pose_video_height = pose_video[0].size + if pose_video_height != height or pose_video_width != width: + logger.warning( + f"Reshaping pose video from ({pose_video_width}, {pose_video_height}) to ({width}, {height})" + ) + pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + + face_video_width, face_video_height = face_video[0].size + expected_face_size = self.transformer.config.motion_encoder_size + if face_video_width != expected_face_size or face_video_height != expected_face_size: + logger.warning( + f"Reshaping face video from ({face_video_width}, {face_video_height}) to ({expected_face_size}," + f" {expected_face_size})" + ) + face_video = self.video_processor.preprocess_video( + face_video, height=expected_face_size, width=expected_face_size + ).to(device, dtype=torch.float32) + + if mode == "replace": + background_video = self.pad_video_frames(background_video, num_target_frames) + mask_video = self.pad_video_frames(mask_video, num_target_frames) + + background_video = self.video_processor.preprocess_video(background_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + mask_video = self.video_processor_for_mask.preprocess_video(mask_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + + # 6. 
Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 7. Prepare latent variables which stay constant for all inference segments + num_channels_latents = self.vae.config.z_dim + + # Get VAE-encoded latents of the reference (character) image + reference_image_latents = self.prepare_reference_image_latents( + image_pixels, batch_size * num_videos_per_prompt, generator=generator, device=device + ) + + # 8. Loop over video inference segments + start = 0 + end = segment_frame_length # Data space frames, not latent frames + all_out_frames = [] + out_frames = None + + for _ in range(num_segments): + assert start + prev_segment_conditioning_frames < cond_video_frames + + # Sample noisy latents from prior for the current inference segment + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=segment_frame_length, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents if start == 0 else None, # Only use pre-calculated latents for first segment + ) + + pose_video_segment = pose_video[:, :, start:end] + face_video_segment = face_video[:, :, start:end] + + face_video_segment = face_video_segment.expand(batch_size * num_videos_per_prompt, -1, -1, -1, -1) + face_video_segment = face_video_segment.to(dtype=transformer_dtype) + + if start > 0: + prev_segment_cond_video = out_frames[:, :, -prev_segment_conditioning_frames:].clone().detach() + else: + prev_segment_cond_video = None + + if mode == "replace": + background_video_segment = background_video[:, :, start:end] + mask_video_segment = mask_video[:, :, start:end] + + background_video_segment = background_video_segment.expand( + batch_size * num_videos_per_prompt, -1, -1, -1, -1 + ) + mask_video_segment = mask_video_segment.expand(batch_size * num_videos_per_prompt, -1, -1, -1, -1) + else: + background_video_segment = None + mask_video_segment = None + + pose_latents = self.prepare_pose_latents( + pose_video_segment, batch_size * num_videos_per_prompt, generator=generator, device=device + ) + pose_latents = pose_latents.to(dtype=transformer_dtype) + + prev_segment_cond_latents = self.prepare_prev_segment_cond_latents( + prev_segment_cond_video, + background_video=background_video_segment, + mask_video=mask_video_segment, + batch_size=batch_size * num_videos_per_prompt, + segment_frame_length=segment_frame_length, + start_frame=start, + height=height, + width=width, + prev_segment_cond_frames=prev_segment_conditioning_frames, + task=mode, + generator=generator, + device=device, + ) + + # Concatenate the reference latents in the frame dimension + reference_latents = torch.cat([reference_image_latents, prev_segment_cond_latents], dim=2) + + # 8.1 Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + # Concatenate the reference image + prev segment conditioning in the channel dim + latent_model_input = torch.cat([latents, reference_latents], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + 
encoder_hidden_states_image=image_embeds, + pose_hidden_states=pose_latents, + face_pixel_values=face_video_segment, + motion_encode_batch_size=motion_encode_batch_size, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + # Blank out face for unconditional guidance (set all pixels to -1) + face_pixel_values_uncond = face_video_segment * 0 - 1 + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + pose_hidden_states=pose_latents, + face_pixel_values=face_pixel_values_uncond, + motion_encode_batch_size=motion_encode_batch_size, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = latents.to(self.vae.dtype) + # Destandardize latents in preparation for Wan VAE decoding + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_recip_std + latents_mean + # Skip the first latent frame (used for conditioning) + out_frames = self.vae.decode(latents[:, :, 1:], return_dict=False)[0] + + if start > 0: + out_frames = out_frames[:, :, prev_segment_conditioning_frames:] + all_out_frames.append(out_frames) + + start += effective_segment_length + end += effective_segment_length + + # Reset scheduler timesteps / state for next denoising loop + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + self._current_timestep = None + assert start + prev_segment_conditioning_frames >= cond_video_frames + + if not output_type == "latent": + video = torch.cat(all_out_frames, dim=2)[:, :, :cond_video_frames] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 81eb2569e303..f56a8b932505 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1623,6 +1623,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class WanAnimateTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + 
requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class WanTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 19f6c0f58440..9e32232133f3 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -3512,6 +3512,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class WanAnimatePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class WanImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_wan_animate.py b/tests/models/transformers/test_models_transformer_wan_animate.py new file mode 100644 index 000000000000..5d571b8c2e7d --- /dev/null +++ b/tests/models/transformers/test_models_transformer_wan_animate.py @@ -0,0 +1,126 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import WanAnimateTransformer3DModel + +from ...testing_utils import ( + enable_full_determinism, + torch_device, +) +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin + + +enable_full_determinism() + + +class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = WanAnimateTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 4 + num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + sequence_length = 12 + + clip_seq_len = 12 + clip_dim = 16 + + inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model + face_height = 16 # Should be square and match `motion_encoder_size` below + face_width = 16 + + hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device) + pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + face_pixel_values = torch.randn((batch_size, 3, inference_segment_length, face_height, face_width)).to( + torch_device + ) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_image": clip_ref_features, + "pose_hidden_states": pose_latents, + "face_pixel_values": face_pixel_values, + } + + @property + def input_shape(self): + return (12, 1, 16, 16) + + @property + def output_shape(self): + return (4, 1, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + # Use custom channel sizes since the default Wan Animate channel sizes will cause the motion encoder to + # contain the vast majority of the parameters in the test model + channel_sizes = {"4": 16, "8": 16, "16": 16} + + init_dict = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 12, + "in_channels": 12, # 2 * C + 4 = 2 * 4 + 4 = 12 + "latent_channels": 4, + "out_channels": 4, + "text_dim": 16, + "freq_dim": 256, + "ffn_dim": 32, + "num_layers": 2, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "image_dim": 16, + "rope_max_seq_len": 32, + "motion_encoder_channel_sizes": channel_sizes, # Start of Wan Animate-specific config + "motion_encoder_size": 16, # Ensures that there will be 2 motion encoder resblocks + "motion_style_dim": 8, + "motion_dim": 4, + "motion_encoder_dim": 16, + "face_encoder_hidden_dim": 16, + "face_encoder_num_heads": 2, + "inject_face_latents_blocks": 2, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"WanAnimateTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + # Override test_output because the transformer output is expected to have less channels than the main transformer + # input. 
+ def test_output(self): + expected_output_shape = (1, 4, 21, 16, 16) + super().test_output(expected_output_shape=expected_output_shape) + + +class WanAnimateTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = WanAnimateTransformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return WanAnimateTransformer3DTests().prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/wan/test_wan_animate.py b/tests/pipelines/wan/test_wan_animate.py new file mode 100644 index 000000000000..d6d1b09f3620 --- /dev/null +++ b/tests/pipelines/wan/test_wan_animate.py @@ -0,0 +1,239 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import ( + AutoTokenizer, + CLIPImageProcessor, + CLIPVisionConfig, + CLIPVisionModelWithProjection, + T5EncoderModel, +) + +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + WanAnimatePipeline, + WanAnimateTransformer3DModel, +) + +from ...testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + slow, + torch_device, +) +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class WanAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WanAnimatePipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + channel_sizes = {"4": 16, "8": 16, "16": 16} + transformer = WanAnimateTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=36, + latent_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + image_dim=4, + rope_max_seq_len=32, + motion_encoder_channel_sizes=channel_sizes, + motion_encoder_size=16, + motion_style_dim=8, + motion_dim=4, + motion_encoder_dim=16, + face_encoder_hidden_dim=16, + face_encoder_num_heads=2, 
+ inject_face_latents_blocks=2, + ) + + torch.manual_seed(0) + image_encoder_config = CLIPVisionConfig( + hidden_size=4, + projection_dim=4, + num_hidden_layers=2, + num_attention_heads=2, + image_size=4, + intermediate_size=16, + patch_size=1, + ) + image_encoder = CLIPVisionModelWithProjection(image_encoder_config) + + torch.manual_seed(0) + image_processor = CLIPImageProcessor(crop_size=4, size=4) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "image_encoder": image_encoder, + "image_processor": image_processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + num_frames = 17 + height = 16 + width = 16 + face_height = 16 + face_width = 16 + + image = Image.new("RGB", (height, width)) + pose_video = [Image.new("RGB", (height, width))] * num_frames + face_video = [Image.new("RGB", (face_height, face_width))] * num_frames + + inputs = { + "image": image, + "pose_video": pose_video, + "face_video": face_video, + "prompt": "dance monkey", + "negative_prompt": "negative", + "height": height, + "width": width, + "segment_frame_length": 77, # TODO: can we set this to num_frames? + "num_inference_steps": 2, + "mode": "animate", + "prev_segment_conditioning_frames": 1, + "generator": generator, + "guidance_scale": 1.0, + "output_type": "pt", + "max_sequence_length": 16, + } + return inputs + + def test_inference(self): + """Test basic inference in animation mode.""" + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames[0] + self.assertEqual(video.shape, (17, 3, 16, 16)) + + expected_video = torch.randn(17, 3, 16, 16) + max_diff = np.abs(video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_inference_replacement(self): + """Test the pipeline in replacement mode with background and mask videos.""" + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["mode"] = "replace" + num_frames = 17 + height = 16 + width = 16 + inputs["background_video"] = [Image.new("RGB", (height, width))] * num_frames + inputs["mask_video"] = [Image.new("L", (height, width))] * num_frames + + video = pipe(**inputs).frames[0] + self.assertEqual(video.shape, (17, 3, 16, 16)) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip( + "Setting the Wan Animate latents to zero at the last denoising step does not guarantee that the output will be" + " zero. I believe this is because the latents are further processed in the outer loop where we loop over" + " inference segments." + ) + def test_callback_inputs(self): + pass + + +@slow +@require_torch_accelerator +class WanAnimatePipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger." 
+ + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + @unittest.skip("TODO: test needs to be implemented") + def test_wan_animate(self): + pass diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 0f4fd408a7c1..b42764be10d6 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -16,6 +16,7 @@ HiDreamImageTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline, + WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, ) @@ -721,6 +722,33 @@ def get_dummy_inputs(self): } +class WanAnimateGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q3_K_S.gguf" + torch_dtype = torch.bfloat16 + model_cls = WanAnimateTransformer3DModel + expected_memory_use_in_gb = 9 + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states": torch.randn( + (1, 96, 2, 64, 64), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states_scale": torch.randn( + (8,), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + + @require_torch_version_greater("2.7.1") class GGUFCompileTests(QuantCompileTests, unittest.TestCase): torch_dtype = torch.bfloat16 From 7a001c3ee2d51ecc69987da050468a318414afa3 Mon Sep 17 00:00:00 2001 From: kaixuanliu Date: Thu, 13 Nov 2025 14:27:12 +0800 Subject: [PATCH 06/35] adjust unit tests for `test_save_load_float16` (#12500) * adjust unit tests for wan pipeline Signed-off-by: Liu, Kaixuan * update code Signed-off-by: Liu, Kaixuan * avoid adjusting common `get_dummy_components` API Signed-off-by: Liu, Kaixuan * use `form_pretrained` to `transformer` and `transformer_2` Signed-off-by: Liu, Kaixuan * update code Signed-off-by: Liu, Kaixuan * update Signed-off-by: Liu, Kaixuan --------- Signed-off-by: Liu, Kaixuan Co-authored-by: Sayak Paul Co-authored-by: Dhruv Nair --- tests/pipelines/test_pipelines_common.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 2af4ad0314c3..e2bbce7b0ead 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1422,7 +1422,18 @@ def test_float16_inference(self, expected_max_diff=5e-2): def test_save_load_float16(self, expected_max_diff=1e-2): components = self.get_dummy_components() for name, module in components.items(): - if hasattr(module, "half"): + # Account for components with _keep_in_fp32_modules + if hasattr(module, "_keep_in_fp32_modules") and module._keep_in_fp32_modules is not None: + for name, param in module.named_parameters(): + if any( + module_to_keep_in_fp32 in name.split(".") + for module_to_keep_in_fp32 in module._keep_in_fp32_modules + ): + param.data = param.data.to(torch_device).to(torch.float32) + else: + param.data = 
param.data.to(torch_device).to(torch.float16) + + elif hasattr(module, "half"): components[name] = module.to(torch_device).half() pipe = self.pipeline_class(**components) From cd3bbe2910666880307b84729176203f5785ff7e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 13 Nov 2025 12:56:22 +0530 Subject: [PATCH 07/35] skip autoencoderdl layerwise casting memory (#12647) --- tests/models/autoencoders/test_models_autoencoder_dc.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/autoencoders/test_models_autoencoder_dc.py b/tests/models/autoencoders/test_models_autoencoder_dc.py index d34001e7b903..b1b5531d0134 100644 --- a/tests/models/autoencoders/test_models_autoencoder_dc.py +++ b/tests/models/autoencoders/test_models_autoencoder_dc.py @@ -82,3 +82,7 @@ def prepare_init_args_and_inputs_for_common(self): @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") def test_layerwise_casting_inference(self): super().test_layerwise_casting_inference() + + @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") + def test_layerwise_casting_memory(self): + super().test_layerwise_casting_memory() From 6a2309b98d415d4ca1da69f59283507fe3eb1d73 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Thu, 13 Nov 2025 08:42:31 -0800 Subject: [PATCH 08/35] [utils] Update check_doc_toc (#12642) update Co-authored-by: Sayak Paul --- docs/source/en/_toctree.yml | 8 ++++---- utils/check_doc_toc.py | 22 ++++++++++++---------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 77eacba664a2..5e9299aece28 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -450,6 +450,8 @@ - sections: - local: api/pipelines/overview title: Overview + - local: api/pipelines/auto_pipeline + title: AutoPipeline - sections: - local: api/pipelines/audioldm title: AudioLDM @@ -462,8 +464,6 @@ - local: api/pipelines/stable_audio title: Stable Audio title: Audio - - local: api/pipelines/auto_pipeline - title: AutoPipeline - sections: - local: api/pipelines/amused title: aMUSEd @@ -527,6 +527,8 @@ title: HiDream-I1 - local: api/pipelines/hunyuandit title: Hunyuan-DiT + - local: api/pipelines/hunyuanimage21 + title: HunyuanImage2.1 - local: api/pipelines/pix2pix title: InstructPix2Pix - local: api/pipelines/kandinsky @@ -640,8 +642,6 @@ title: ConsisID - local: api/pipelines/framepack title: Framepack - - local: api/pipelines/hunyuanimage21 - title: HunyuanImage2.1 - local: api/pipelines/hunyuan_video title: HunyuanVideo - local: api/pipelines/i2vgenxl diff --git a/utils/check_doc_toc.py b/utils/check_doc_toc.py index 0dd02cde86c1..050b093991e6 100644 --- a/utils/check_doc_toc.py +++ b/utils/check_doc_toc.py @@ -21,20 +21,23 @@ PATH_TO_TOC = "docs/source/en/_toctree.yml" +# Titles that should maintain their position and not be sorted alphabetically +FIXED_POSITION_TITLES = {"overview", "autopipeline"} + def clean_doc_toc(doc_list): """ Cleans the table of content of the model documentation by removing duplicates and sorting models alphabetically. 
""" counts = defaultdict(int) - overview_doc = [] + fixed_position_docs = [] new_doc_list = [] for doc in doc_list: if "local" in doc: counts[doc["local"]] += 1 - if doc["title"].lower() == "overview": - overview_doc.append({"local": doc["local"], "title": doc["title"]}) + if doc["title"].lower() in FIXED_POSITION_TITLES: + fixed_position_docs.append({"local": doc["local"], "title": doc["title"]}) else: new_doc_list.append(doc) @@ -57,14 +60,13 @@ def clean_doc_toc(doc_list): new_doc.extend([doc for doc in doc_list if "local" not in counts or counts[doc["local"]] == 1]) new_doc = sorted(new_doc, key=lambda s: s["title"].lower()) - # "overview" gets special treatment and is always first - if len(overview_doc) > 1: - raise ValueError("{doc_list} has two 'overview' docs which is not allowed.") - - overview_doc.extend(new_doc) + # Fixed-position titles maintain their original order + result = [] + for doc in fixed_position_docs: + result.append(doc) - # Sort - return overview_doc + result.extend(new_doc) + return result def check_scheduler_doc(overwrite=False): From 40de88af8c8ef6ecd69f99dabeeb07f8362fcf87 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Thu, 13 Nov 2025 08:43:24 -0800 Subject: [PATCH 09/35] [docs] AutoModel (#12644) * automodel * fix --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/models/auto_model.md | 10 +---- docs/source/en/using-diffusers/automodel.md | 46 +++++++++++++++++++++ 3 files changed, 49 insertions(+), 9 deletions(-) create mode 100644 docs/source/en/using-diffusers/automodel.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 5e9299aece28..e3b9f99927e8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -22,6 +22,8 @@ title: Reproducibility - local: using-diffusers/schedulers title: Schedulers + - local: using-diffusers/automodel + title: AutoModel - local: using-diffusers/other-formats title: Model formats - local: using-diffusers/push_to_hub diff --git a/docs/source/en/api/models/auto_model.md b/docs/source/en/api/models/auto_model.md index 376dd12d12c4..aee9b5dbe50c 100644 --- a/docs/source/en/api/models/auto_model.md +++ b/docs/source/en/api/models/auto_model.md @@ -12,15 +12,7 @@ specific language governing permissions and limitations under the License. # AutoModel -The `AutoModel` is designed to make it easy to load a checkpoint without needing to know the specific model class. `AutoModel` automatically retrieves the correct model class from the checkpoint `config.json` file. - -```python -from diffusers import AutoModel, AutoPipelineForText2Image - -unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet") -pipe = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet) -``` - +[`AutoModel`] automatically retrieves the correct model class from the checkpoint `config.json` file. ## AutoModel diff --git a/docs/source/en/using-diffusers/automodel.md b/docs/source/en/using-diffusers/automodel.md new file mode 100644 index 000000000000..957cbd17e3f7 --- /dev/null +++ b/docs/source/en/using-diffusers/automodel.md @@ -0,0 +1,46 @@ + + +# AutoModel + +The [`AutoModel`] class automatically detects and loads the correct model class (UNet, transformer, VAE) from a `config.json` file. You don't need to know the specific model class name ahead of time. It supports data types and device placement, and works across model types and libraries. 
+ +The example below loads a transformer from Diffusers and a text encoder from Transformers. Use the `subfolder` parameter to specify where to load the `config.json` file from. + +```py +import torch +from diffusers import AutoModel, DiffusionPipeline + +transformer = AutoModel.from_pretrained( + "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda" +) + +text_encoder = AutoModel.from_pretrained( + "Qwen/Qwen-Image", subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cuda" +) +``` + +[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models. + +```py +import torch +from diffusers import AutoModel + +transformer = AutoModel.from_pretrained( + "custom/custom-transformer-model", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda" +) +``` + +If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading). + +> [!NOTE] +> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide. \ No newline at end of file From 6fe4a6ff8ece2d7e734a62a9e439fb00eb03dc69 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Thu, 13 Nov 2025 23:45:58 +0100 Subject: [PATCH 10/35] Improve docstrings and type hints in scheduling_ddim.py (#12622) * Improve docstrings and type hints in scheduling_ddim.py - Add complete type hints for all function parameters - Enhance docstrings to follow project conventions - Add missing parameter descriptions Fixes #9567 * Enhance docstrings and type hints in scheduling_ddim.py - Update parameter types and descriptions for clarity - Improve explanations in method docstrings to align with project standards - Add optional annotations for parameters where applicable * Refine type hints and docstrings in scheduling_ddim.py - Update parameter types to use Literal for specific string options - Enhance docstring descriptions for clarity and consistency - Ensure all parameters have appropriate type annotations and defaults * Apply review feedback on scheduling_ddim.py - Replace "prevent singularities" with "avoid numerical instability" for better clarity - Add backticks around `alpha_bar` variable name for consistent formatting - Convert Imagen Video paper URLs to Hugging Face papers references * Propagate changes using 'make fix-copies' * Add missing Literal --- src/diffusers/schedulers/scheduling_ddim.py | 76 ++++++++++++------- .../schedulers/scheduling_ddim_inverse.py | 1 - .../schedulers/scheduling_ddim_parallel.py | 14 ++-- src/diffusers/schedulers/scheduling_ddpm.py | 1 - .../schedulers/scheduling_ddpm_parallel.py | 1 - .../scheduling_dpmsolver_multistep.py | 1 - .../scheduling_euler_ancestral_discrete.py | 1 - .../schedulers/scheduling_euler_discrete.py | 1 - src/diffusers/schedulers/scheduling_lcm.py | 1 - src/diffusers/schedulers/scheduling_tcd.py | 19 ++++- .../schedulers/scheduling_unipc_multistep.py | 1 - 11 files changed, 77 insertions(+), 40 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 5ee0d084f060..c63f1f4c1675 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ 
-17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -92,11 +92,10 @@ def alpha_bar_fn(t): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(betas): +def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. @@ -143,9 +142,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): - The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Must be one + of `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. clip_sample (`bool`, defaults to `True`): @@ -158,10 +157,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): otherwise it uses the alpha value at step 0. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. - prediction_type (`str`, defaults to `epsilon`, *optional*): - Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen - Video](https://imagen.research.google/video/paper.pdf) paper). + prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`): + Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion + process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen + Video](https://huggingface.co/papers/2210.02303) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. @@ -169,9 +168,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True`. - timestep_spacing (`str`, defaults to `"leading"`): - The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + timestep_spacing (`Literal["leading", "trailing", "linspace"]`, defaults to `"leading"`): + The way the timesteps should be scaled. Must be one of `"leading"`, `"trailing"`, or `"linspace"`. Refer to + Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) for more information. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. 
This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to @@ -187,17 +187,17 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading", rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: @@ -250,7 +250,25 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None """ return sample - def _get_variance(self, timestep, prev_timestep): + def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: + """ + Computes the variance of the noise added at a given diffusion step. + + For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM + literature: + var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + where alpha_prod and beta_prod are cumulative products of alphas and betas, respectively. + + Args: + timestep (`int`): + The current timestep in the diffusion process. + prev_timestep (`int`): + The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`. + + Returns: + `torch.Tensor`: + The variance for the current timestep. + """ alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t @@ -294,13 +312,18 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + device (`Union[str, torch.device]`, *optional*): + The device to use for the timesteps. + + Raises: + ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`. """ if num_inference_steps > self.config.num_train_timesteps: @@ -346,7 +369,7 @@ def step( sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, - generator=None, + generator: Optional[torch.Generator] = None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: @@ -357,20 +380,21 @@ def step( Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`float`): + timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - eta (`float`): - The weight of noise for added noise in diffusion step. 
- use_clipped_model_output (`bool`, defaults to `False`): + eta (`float`, *optional*, defaults to 0.0): + The weight of noise for added noise in diffusion step. A value of 0 corresponds to DDIM (deterministic) + and 1 corresponds to DDPM (fully stochastic). + use_clipped_model_output (`bool`, *optional*, defaults to `False`): If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would coincide with the one provided as input and `use_clipped_model_output` has no effect. generator (`torch.Generator`, *optional*): - A random number generator. - variance_noise (`torch.Tensor`): + A random number generator for reproducible sampling. + variance_noise (`torch.Tensor`, *optional*): Alternative to generating noise with `generator` by directly providing the noise for the variance itself. Useful for methods such as [`CycleDiffusion`]. return_dict (`bool`, *optional*, defaults to `True`): @@ -517,5 +541,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index 49dba840d089..d13ac606805c 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -95,7 +95,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py index 7c3f03a8dbe1..deffdb4ff7d3 100644 --- a/src/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -97,7 +97,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. 
@@ -194,17 +193,17 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading", rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: @@ -324,6 +323,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + device (`Union[str, torch.device]`, *optional*): + The device to use for the timesteps. + + Raises: + ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`. """ if num_inference_steps > self.config.num_train_timesteps: diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 0fab6d910a82..b59fae066495 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -94,7 +94,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index ec741f9ecb7d..c78bfe290f53 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -96,7 +96,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 8b523cd13f1f..0560a030321d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -80,7 +80,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 9cdaa2c5e101..38ad401edc49 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -97,7 +97,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. 
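The inline formula added to the `_get_variance` docstrings in this commit uses the cumulative-product shorthand `beta_prod_t = 1 - alpha_prod_t` (i.e. 1 − ᾱ_t). For reference, the same expression in ᾱ notation reads:

$$
\sigma_t^2 \;=\; \frac{1 - \bar{\alpha}_{t_\text{prev}}}{1 - \bar{\alpha}_t}\left(1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t_\text{prev}}}\right)
$$

which is the variance from the DDIM/DDPM literature that the docstring cites, for the step from `timestep` to `prev_timestep`.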
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index f58d918dbfbe..59199bf71013 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -100,7 +100,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py index cd7a29fe675f..8a0fd480505c 100644 --- a/src/diffusers/schedulers/scheduling_lcm.py +++ b/src/diffusers/schedulers/scheduling_lcm.py @@ -99,7 +99,6 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py index 3fd5c341eca9..ce7d1d5316b4 100644 --- a/src/diffusers/schedulers/scheduling_tcd.py +++ b/src/diffusers/schedulers/scheduling_tcd.py @@ -98,7 +98,6 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. @@ -316,6 +315,24 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler._get_variance def _get_variance(self, timestep, prev_timestep): + """ + Computes the variance of the noise added at a given diffusion step. + + For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM + literature: + var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + where alpha_prod and beta_prod are cumulative products of alphas and betas, respectively. + + Args: + timestep (`int`): + The current timestep in the diffusion process. + prev_timestep (`int`): + The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`. + + Returns: + `torch.Tensor`: + The variance for the current timestep. + """ alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 162a34bd2774..a596fef24559 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -80,7 +80,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. 
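The options documented in this commit (`rescale_betas_zero_snr`, `timestep_spacing`, `prediction_type`) are typically combined along the lines of the sketch below when configuring a `DDIMScheduler` for zero-terminal-SNR sampling as described in the paper cited in the docstrings. The checkpoint name and CUDA device are illustrative assumptions only; in practice this setup is usually paired with a checkpoint fine-tuned for `v_prediction`.

```py
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

# Minimal sketch: checkpoint and device are assumptions for illustration.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Rescale betas to zero terminal SNR and sample from the last timestep, following
# Table 2 of https://huggingface.co/papers/2305.08891 as referenced in the docstrings above.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)

image = pipe("a photo of an astronaut riding a horse", guidance_rescale=0.7).images[0]
```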
From 3c1ca869d78c41ed144bdfc0a6574d229c537144 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Thu, 13 Nov 2025 23:46:23 +0100 Subject: [PATCH 11/35] Improve docstrings and type hints in scheduling_ddpm.py (#12651) * Enhance type hints and docstrings in scheduling_ddpm.py - Added type hints for function parameters and return types across the DDPMScheduler class and related functions. - Improved docstrings for clarity, including detailed descriptions of parameters and return values. - Updated the alpha_transform_type and beta_schedule parameters to use Literal types for better type safety. - Refined the _get_variance and previous_timestep methods with comprehensive documentation. * Refactor docstrings and type hints in scheduling_ddpm.py - Cleaned up whitespace in the rescale_zero_terminal_snr function. - Enhanced the variance_type parameter in the DDPMScheduler class with improved formatting for better readability. - Updated the docstring for the compute_variance method to maintain consistency and clarity in parameter descriptions and return values. * Apply `make fix-copies` * Refactor type hints across multiple scheduler files - Updated type hints to include `Literal` for improved type safety in various scheduling files. - Ensured consistency in type hinting for parameters and return types across the affected modules. - This change enhances code clarity and maintainability. --- .../scheduling_consistency_decoder.py | 25 +-- src/diffusers/schedulers/scheduling_ddim.py | 64 ++++++-- .../schedulers/scheduling_ddim_cogvideox.py | 56 +++++-- .../schedulers/scheduling_ddim_inverse.py | 25 +-- .../schedulers/scheduling_ddim_parallel.py | 64 ++++++-- src/diffusers/schedulers/scheduling_ddpm.py | 154 +++++++++++++----- .../schedulers/scheduling_ddpm_parallel.py | 113 +++++++++++-- .../schedulers/scheduling_deis_multistep.py | 35 ++-- .../schedulers/scheduling_dpm_cogvideox.py | 56 +++++-- .../scheduling_dpmsolver_multistep.py | 35 ++-- .../scheduling_dpmsolver_multistep_inverse.py | 35 ++-- .../schedulers/scheduling_dpmsolver_sde.py | 25 +-- .../scheduling_dpmsolver_singlestep.py | 35 ++-- .../scheduling_edm_dpmsolver_multistep.py | 10 ++ .../scheduling_euler_ancestral_discrete.py | 25 +-- .../schedulers/scheduling_euler_discrete.py | 25 +-- .../schedulers/scheduling_heun_discrete.py | 25 +-- .../scheduling_k_dpm_2_ancestral_discrete.py | 25 +-- .../schedulers/scheduling_k_dpm_2_discrete.py | 25 +-- src/diffusers/schedulers/scheduling_lcm.py | 77 +++++++-- .../schedulers/scheduling_lms_discrete.py | 25 +-- src/diffusers/schedulers/scheduling_pndm.py | 41 +++-- .../schedulers/scheduling_repaint.py | 25 +-- .../schedulers/scheduling_sasolver.py | 51 ++++-- src/diffusers/schedulers/scheduling_tcd.py | 77 +++++++-- src/diffusers/schedulers/scheduling_unclip.py | 41 +++-- .../schedulers/scheduling_unipc_multistep.py | 35 ++-- 27 files changed, 887 insertions(+), 342 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_consistency_decoder.py b/src/diffusers/schedulers/scheduling_consistency_decoder.py index d7af018b284a..767fa9157f59 100644 --- a/src/diffusers/schedulers/scheduling_consistency_decoder.py +++ b/src/diffusers/schedulers/scheduling_consistency_decoder.py @@ -1,6 +1,6 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union import torch @@ -12,10 +12,10 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - 
max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -23,16 +23,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index c63f1f4c1675..5ddc46ee4d2f 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -49,10 +49,10 @@ class DDIMSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -60,16 +60,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -281,6 +282,8 @@ def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ + Apply dynamic thresholding to the predicted sample. 
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing @@ -288,6 +291,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: photorealism as well as better image-text alignment, especially when using very large guidance weights." https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The predicted sample to be thresholded. + + Returns: + `torch.Tensor`: + The thresholded sample. """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape @@ -501,6 +512,22 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise magnitude at each timestep (this is the forward + diffusion process). + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps indicating the noise level for each sample. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls @@ -523,6 +550,21 @@ def add_noise( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + """ + Compute the velocity prediction from the sample and noise according to the velocity formula. + + Args: + sample (`torch.Tensor`): + The input sample. + noise (`torch.Tensor`): + The noise tensor. + timesteps (`torch.IntTensor`): + The timesteps for velocity computation. + + Returns: + `torch.Tensor`: + The computed velocity. + """ # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index c19efdc7834d..acb5a5f3e522 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -18,7 +18,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -49,10 +49,10 @@ class DDIMSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -60,16 +60,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. 
- Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -408,6 +409,22 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise magnitude at each timestep (this is the forward + diffusion process). + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps indicating the noise level for each sample. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls @@ -430,6 +447,21 @@ def add_noise( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + """ + Compute the velocity prediction from the sample and noise according to the velocity formula. + + Args: + sample (`torch.Tensor`): + The input sample. + noise (`torch.Tensor`): + The noise tensor. + timesteps (`torch.IntTensor`): + The timesteps for velocity computation. + + Returns: + `torch.Tensor`: + The computed velocity. + """ # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index d13ac606805c..bed424e320be 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -16,7 +16,7 @@ # and https://github.com/hojonathanho/diffusion import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -47,10 +47,10 @@ class DDIMSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. 
@@ -58,16 +58,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py index deffdb4ff7d3..1432d835aea2 100644 --- a/src/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -49,10 +49,10 @@ class DDIMParallelSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -60,16 +60,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -284,6 +285,8 @@ def _batch_get_variance(self, t, prev_t): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ + Apply dynamic thresholding to the predicted sample. + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. 
Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing @@ -291,6 +294,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: photorealism as well as better image-text alignment, especially when using very large guidance weights." https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The predicted sample to be thresholded. + + Returns: + `torch.Tensor`: + The thresholded sample. """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape @@ -606,6 +617,22 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise magnitude at each timestep (this is the forward + diffusion process). + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps indicating the noise level for each sample. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls @@ -628,6 +655,21 @@ def add_noise( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + """ + Compute the velocity prediction from the sample and noise according to the velocity formula. + + Args: + sample (`torch.Tensor`): + The input sample. + noise (`torch.Tensor`): + The noise tensor. + timesteps (`torch.IntTensor`): + The timesteps for velocity computation. + + Returns: + `torch.Tensor`: + The computed velocity. + """ # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index b59fae066495..5ccf4adaebbc 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -46,10 +46,10 @@ class DDPMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -57,16 +57,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
- Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -90,7 +91,7 @@ def alpha_bar_fn(t): # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr -def rescale_zero_terminal_snr(betas): +def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) @@ -133,39 +134,37 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): methods the library implements for all schedulers such as loading and saving. Args: - num_train_timesteps (`int`, defaults to 1000): + num_train_timesteps (`int`, defaults to `1000`): The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): + beta_start (`float`, defaults to `0.0001`): The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): + beta_end (`float`, defaults to `0.02`): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): - The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, `squaredcos_cap_v2`, or `sigmoid`. + beta_schedule (`"linear"`, `"scaled_linear"`, `"squaredcos_cap_v2"`, or `"sigmoid"`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. trained_betas (`np.ndarray`, *optional*): An array of betas to pass directly to the constructor without using `beta_start` and `beta_end`. - variance_type (`str`, defaults to `"fixed_small"`): - Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, - `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + variance_type (`"fixed_small"`, `"fixed_small_log"`, `"fixed_large"`, `"fixed_large_log"`, `"learned"`, or `"learned_range"`, defaults to `"fixed_small"`): + Clip the variance when adding noise to the denoised sample. clip_sample (`bool`, defaults to `True`): Clip the predicted sample for numerical stability. - clip_sample_range (`float`, defaults to 1.0): + clip_sample_range (`float`, defaults to `1.0`): The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. - prediction_type (`str`, defaults to `epsilon`, *optional*): + prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. - dynamic_thresholding_ratio (`float`, defaults to 0.995): + dynamic_thresholding_ratio (`float`, defaults to `0.995`): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. 
- sample_max_value (`float`, defaults to 1.0): + sample_max_value (`float`, defaults to `1.0`): The threshold value for dynamic thresholding. Valid only when `thresholding=True`. - timestep_spacing (`str`, defaults to `"leading"`): + timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, defaults to 0): + steps_offset (`int`, defaults to `0`): An offset added to the inference steps, as required by some model families. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and @@ -182,16 +181,18 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - variance_type: str = "fixed_small", + variance_type: Literal[ + "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range" + ] = "fixed_small", clip_sample: bool = True, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading", steps_offset: int = 0, rescale_betas_zero_snr: bool = False, ): @@ -321,7 +322,31 @@ def set_timesteps( self.timesteps = torch.from_numpy(timesteps).to(device) - def _get_variance(self, t, predicted_variance=None, variance_type=None): + def _get_variance( + self, + t: int, + predicted_variance: Optional[torch.Tensor] = None, + variance_type: Optional[ + Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"] + ] = None, + ) -> torch.Tensor: + """ + Compute the variance for a given timestep according to the specified variance type. + + Args: + t (`int`): + The current timestep. + predicted_variance (`torch.Tensor`, *optional*): + The predicted variance from the model. Used only when `variance_type` is `"learned"` or + `"learned_range"`. + variance_type (`"fixed_small"`, `"fixed_small_log"`, `"fixed_large"`, `"fixed_large_log"`, `"learned"`, or `"learned_range"`, *optional*): + The type of variance to compute. If `None`, uses the variance type specified in the scheduler + configuration. + + Returns: + `torch.Tensor`: + The computed variance. + """ prev_t = self.previous_timestep(t) alpha_prod_t = self.alphas_cumprod[t] @@ -363,6 +388,8 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ + Apply dynamic thresholding to the predicted sample. + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. 
Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing @@ -370,6 +397,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: photorealism as well as better image-text alignment, especially when using very large guidance weights." https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The predicted sample to be thresholded. + + Returns: + `torch.Tensor`: + The thresholded sample. """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape @@ -399,7 +434,7 @@ def step( model_output: torch.Tensor, timestep: int, sample: torch.Tensor, - generator=None, + generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[DDPMSchedulerOutput, Tuple]: """ @@ -409,20 +444,19 @@ def step( Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`float`): + timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. generator (`torch.Generator`, *optional*): A random number generator. - return_dict (`bool`, *optional*, defaults to `True`): + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`. Returns: [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. - """ t = timestep @@ -503,6 +537,22 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise magnitude at each timestep (this is the forward + diffusion process). + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps indicating the noise level for each sample. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls @@ -524,6 +574,21 @@ def add_noise( return noisy_samples def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + """ + Compute the velocity prediction from the sample and noise according to the velocity formula. + + Args: + sample (`torch.Tensor`): + The input sample. + noise (`torch.Tensor`): + The noise tensor. + timesteps (`torch.IntTensor`): + The timesteps for velocity computation. + + Returns: + `torch.Tensor`: + The computed velocity. + """ # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) @@ -542,10 +607,21 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps - def previous_timestep(self, timestep): + def previous_timestep(self, timestep: int) -> int: + """ + Compute the previous timestep in the diffusion chain. 
+ + Args: + timestep (`int`): + The current timestep. + + Returns: + `int`: + The previous timestep. + """ if self.custom_timesteps or self.num_inference_steps: index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.timesteps.shape[0] - 1: diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index c78bfe290f53..8740f14c66b4 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -48,10 +48,10 @@ class DDPMParallelSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -59,16 +59,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. 
""" if alpha_transform_type == "cosine": @@ -190,16 +191,18 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - variance_type: str = "fixed_small", + variance_type: Literal[ + "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range" + ] = "fixed_small", clip_sample: bool = True, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading", steps_offset: int = 0, rescale_betas_zero_snr: bool = False, ): @@ -332,7 +335,31 @@ def set_timesteps( self.timesteps = torch.from_numpy(timesteps).to(device) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._get_variance - def _get_variance(self, t, predicted_variance=None, variance_type=None): + def _get_variance( + self, + t: int, + predicted_variance: Optional[torch.Tensor] = None, + variance_type: Optional[ + Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"] + ] = None, + ) -> torch.Tensor: + """ + Compute the variance for a given timestep according to the specified variance type. + + Args: + t (`int`): + The current timestep. + predicted_variance (`torch.Tensor`, *optional*): + The predicted variance from the model. Used only when `variance_type` is `"learned"` or + `"learned_range"`. + variance_type (`"fixed_small"`, `"fixed_small_log"`, `"fixed_large"`, `"fixed_large_log"`, `"learned"`, or `"learned_range"`, *optional*): + The type of variance to compute. If `None`, uses the variance type specified in the scheduler + configuration. + + Returns: + `torch.Tensor`: + The computed variance. + """ prev_t = self.previous_timestep(t) alpha_prod_t = self.alphas_cumprod[t] @@ -375,6 +402,8 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ + Apply dynamic thresholding to the predicted sample. + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing @@ -382,6 +411,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: photorealism as well as better image-text alignment, especially when using very large guidance weights." https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The predicted sample to be thresholded. + + Returns: + `torch.Tensor`: + The thresholded sample. """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape @@ -592,6 +629,22 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise magnitude at each timestep (this is the forward + diffusion process). 
+ + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps indicating the noise level for each sample. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls @@ -614,6 +667,21 @@ def add_noise( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + """ + Compute the velocity prediction from the sample and noise according to the velocity formula. + + Args: + sample (`torch.Tensor`): + The input sample. + noise (`torch.Tensor`): + The noise tensor. + timesteps (`torch.IntTensor`): + The timesteps for velocity computation. + + Returns: + `torch.Tensor`: + The computed velocity. + """ # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) @@ -637,6 +705,17 @@ def __len__(self): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep def previous_timestep(self, timestep): + """ + Compute the previous timestep in the diffusion chain. + + Args: + timestep (`int`): + The current timestep. + + Returns: + `int`: + The previous timestep. + """ if self.custom_timesteps or self.num_inference_steps: index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.timesteps.shape[0] - 1: diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 7d8685ba10c3..15d8a20e33b8 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -16,7 +16,7 @@ # The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py import math -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -32,10 +32,10 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -43,16 +43,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. 
+ max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -320,6 +321,8 @@ def set_timesteps( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ + Apply dynamic thresholding to the predicted sample. + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing @@ -327,6 +330,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: photorealism as well as better image-text alignment, especially when using very large guidance weights." https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The predicted sample to be thresholded. + + Returns: + `torch.Tensor`: + The thresholded sample. """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py index f7b63720e107..c5d79b5fe54a 100644 --- a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py @@ -18,7 +18,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -50,10 +50,10 @@ class DDIMSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -61,16 +61,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. 
""" if alpha_transform_type == "cosine": @@ -445,6 +446,22 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise magnitude at each timestep (this is the forward + diffusion process). + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps indicating the noise level for each sample. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls @@ -467,6 +484,21 @@ def add_noise( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + """ + Compute the velocity prediction from the sample and noise according to the velocity formula. + + Args: + sample (`torch.Tensor`): + The input sample. + noise (`torch.Tensor`): + The noise tensor. + timesteps (`torch.IntTensor`): + The timesteps for velocity computation. + + Returns: + `torch.Tensor`: + The computed velocity. + """ # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 0560a030321d..b1f218a5eb06 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -15,7 +15,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver import math -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -32,10 +32,10 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -43,16 +43,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. 
Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -459,6 +460,8 @@ def set_timesteps( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ + Apply dynamic thresholding to the predicted sample. + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing @@ -466,6 +469,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: photorealism as well as better image-text alignment, especially when using very large guidance weights." https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The predicted sample to be thresholded. + + Returns: + `torch.Tensor`: + The thresholded sample. """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index f1a1ac3d8216..476d2fc10568 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -15,7 +15,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver import math -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -32,10 +32,10 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -43,16 +43,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. 
""" if alpha_transform_type == "cosine": @@ -332,6 +333,8 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ + Apply dynamic thresholding to the predicted sample. + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing @@ -339,6 +342,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: photorealism as well as better image-text alignment, especially when using very large guidance weights." https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The predicted sample to be thresholded. + + Returns: + `torch.Tensor`: + The thresholded sample. """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index eeb06773d977..2b02c2fd5e57 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -115,10 +115,10 @@ def __call__(self, sigma, sigma_next): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -126,16 +126,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. 
""" if alpha_transform_type == "cosine": diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 1ae824973034..e7fde2c2ba0d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -15,7 +15,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver import math -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -34,10 +34,10 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -45,16 +45,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -410,6 +411,8 @@ def set_timesteps( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ + Apply dynamic thresholding to the predicted sample. + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing @@ -417,6 +420,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: photorealism as well as better image-text alignment, especially when using very large guidance weights." https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The predicted sample to be thresholded. + + Returns: + `torch.Tensor`: + The thresholded sample. 
""" dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py index e9ba695e1f39..f748c6c834a3 100644 --- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py @@ -299,6 +299,8 @@ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> t # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ + Apply dynamic thresholding to the predicted sample. + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing @@ -306,6 +308,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: photorealism as well as better image-text alignment, especially when using very large guidance weights." https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The predicted sample to be thresholded. + + Returns: + `torch.Tensor`: + The thresholded sample. """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 38ad401edc49..b2741c586be2 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -49,10 +49,10 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -60,16 +60,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. 
Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 59199bf71013..f88b124f04e7 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -52,10 +52,10 @@ class EulerDiscreteSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -63,16 +63,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index bd1239cfaec7..db81fc82bcf3 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -49,10 +49,10 @@ class HeunDiscreteSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -60,16 +60,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. 
- Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 6588464073a1..48cc01e6aac7 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -50,10 +50,10 @@ class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -61,16 +61,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. 
""" if alpha_transform_type == "cosine": diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 9b4cd4e204d6..aaf6a48b57be 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -49,10 +49,10 @@ class KDPM2DiscreteSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -60,16 +60,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py index 8a0fd480505c..36587537ec1b 100644 --- a/src/diffusers/schedulers/scheduling_lcm.py +++ b/src/diffusers/schedulers/scheduling_lcm.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -51,10 +51,10 @@ class LCMSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -62,16 +62,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. 
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -314,6 +315,8 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ + Apply dynamic thresholding to the predicted sample. + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing @@ -321,6 +324,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: photorealism as well as better image-text alignment, especially when using very large guidance weights." https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The predicted sample to be thresholded. + + Returns: + `torch.Tensor`: + The thresholded sample. """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape @@ -596,6 +607,22 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise magnitude at each timestep (this is the forward + diffusion process). + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps indicating the noise level for each sample. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls @@ -618,6 +645,21 @@ def add_noise( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + """ + Compute the velocity prediction from the sample and noise according to the velocity formula. + + Args: + sample (`torch.Tensor`): + The input sample. + noise (`torch.Tensor`): + The noise tensor. + timesteps (`torch.IntTensor`): + The timesteps for velocity computation. + + Returns: + `torch.Tensor`: + The computed velocity. + """ # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) @@ -641,6 +683,17 @@ def __len__(self): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep def previous_timestep(self, timestep): + """ + Compute the previous timestep in the diffusion chain. 
+ + Args: + timestep (`int`): + The current timestep. + + Returns: + `int`: + The previous timestep. + """ if self.custom_timesteps or self.num_inference_steps: index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.timesteps.shape[0] - 1: diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index c2450204aa8f..6fa9c2f7fbcf 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -14,7 +14,7 @@ import math import warnings from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import scipy.stats @@ -47,10 +47,10 @@ class LMSDiscreteSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -58,16 +58,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index c07621179e2b..aded6c224671 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -15,7 +15,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -26,10 +26,10 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -37,16 +37,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. 
- Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -452,6 +453,22 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise magnitude at each timestep (this is the forward + diffusion process). + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps indicating the noise level for each sample. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py index 6530c5af9e5b..a2eaf8eb3abd 100644 --- a/src/diffusers/schedulers/scheduling_repaint.py +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union import numpy as np import torch @@ -45,10 +45,10 @@ class RePaintSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -56,16 +56,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. 
Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index 2979ce193a36..30a3eb294a04 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -16,7 +16,7 @@ # The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py import math -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -33,10 +33,10 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -44,16 +44,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -342,6 +343,8 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ + Apply dynamic thresholding to the predicted sample. + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing @@ -349,6 +352,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: photorealism as well as better image-text alignment, especially when using very large guidance weights." https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The predicted sample to be thresholded. + + Returns: + `torch.Tensor`: + The thresholded sample. 
""" dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape @@ -1193,6 +1204,22 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise magnitude at each timestep (this is the forward + diffusion process). + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps indicating the noise level for each sample. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py index ce7d1d5316b4..101b1569a145 100644 --- a/src/diffusers/schedulers/scheduling_tcd.py +++ b/src/diffusers/schedulers/scheduling_tcd.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -50,10 +50,10 @@ class TCDSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -61,16 +61,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -345,6 +346,8 @@ def _get_variance(self, timestep, prev_timestep): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ + Apply dynamic thresholding to the predicted sample. + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. 
Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing @@ -352,6 +355,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: photorealism as well as better image-text alignment, especially when using very large guidance weights." https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The predicted sample to be thresholded. + + Returns: + `torch.Tensor`: + The thresholded sample. """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape @@ -651,6 +662,22 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise magnitude at each timestep (this is the forward + diffusion process). + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps indicating the noise level for each sample. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls @@ -673,6 +700,21 @@ def add_noise( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + """ + Compute the velocity prediction from the sample and noise according to the velocity formula. + + Args: + sample (`torch.Tensor`): + The input sample. + noise (`torch.Tensor`): + The noise tensor. + timesteps (`torch.IntTensor`): + The timesteps for velocity computation. + + Returns: + `torch.Tensor`: + The computed velocity. + """ # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) @@ -696,6 +738,17 @@ def __len__(self): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep def previous_timestep(self, timestep): + """ + Compute the previous timestep in the diffusion chain. + + Args: + timestep (`int`): + The current timestep. + + Returns: + `int`: + The previous timestep. + """ if self.custom_timesteps or self.num_inference_steps: index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.timesteps.shape[0] - 1: diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py index d78efabfbc57..5a978dec649b 100644 --- a/src/diffusers/schedulers/scheduling_unclip.py +++ b/src/diffusers/schedulers/scheduling_unclip.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union import numpy as np import torch @@ -46,10 +46,10 @@ class UnCLIPSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. 
@@ -57,16 +57,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -334,6 +335,22 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise magnitude at each timestep (this is the forward + diffusion process). + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps indicating the noise level for each sample. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index a596fef24559..d985871df109 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -16,7 +16,7 @@ # The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py import math -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -32,10 +32,10 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp"] = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -43,16 +43,17 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. 
+ max_beta (`float`, defaults to `0.999`): + The maximum beta to use; use values lower than 1 to avoid numerical instability. + alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: + The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -431,6 +432,8 @@ def set_timesteps( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ + Apply dynamic thresholding to the predicted sample. + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing @@ -438,6 +441,14 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: photorealism as well as better image-text alignment, especially when using very large guidance weights." https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The predicted sample to be thresholded. + + Returns: + `torch.Tensor`: + The thresholded sample. """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape From eeae0338e7ad2b3749eac0c8701ec250a1884844 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 14 Nov 2025 10:59:59 +0530 Subject: [PATCH 12/35] [Modular] Add Custom Blocks guide to doc (#12339) * update * update * Update docs/source/en/modular_diffusers/custom_blocks.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/modular_diffusers/custom_blocks.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/_toctree.yml Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/modular_diffusers/custom_blocks.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Apply suggestion from @stevhliu Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Apply suggestion from @stevhliu Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * update * update * update * Apply suggestion from @stevhliu Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Apply suggestion from @stevhliu Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * update * update * update * update * update * Update docs/source/en/modular_diffusers/custom_blocks.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + .../en/modular_diffusers/custom_blocks.md | 492 ++++++++++++++++++ 2 files changed, 494 insertions(+) create mode 100644 docs/source/en/modular_diffusers/custom_blocks.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e3b9f99927e8..24420af8e490 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -121,6 +121,8 @@ title: ComponentsManager - local: modular_diffusers/guiders title: Guiders + - local: modular_diffusers/custom_blocks + title: Building Custom Blocks title: Modular Diffusers - isExpanded: false sections: 
diff --git a/docs/source/en/modular_diffusers/custom_blocks.md b/docs/source/en/modular_diffusers/custom_blocks.md new file mode 100644 index 000000000000..1c311582264e --- /dev/null +++ b/docs/source/en/modular_diffusers/custom_blocks.md @@ -0,0 +1,492 @@ + + + +# Building Custom Blocks + +[ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block. + +> [!TIP] +> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom modular blocks like Nano Banana. + +## Project Structure + +Your custom block project should use the following structure: + +```shell +. +├── block.py +└── modular_config.json +``` + +- `block.py` contains the custom block implementation +- `modular_config.json` contains the metadata needed to load the block + +## Example: Florence 2 Inpainting Block + +In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting. + +The first step is to define the components that the block will use. In this case, we will need to use the `Florence2ForConditionalGeneration` model and its corresponding processor `AutoProcessor`. When defining components, we must specify the name of the component within our pipeline, model class via `type_hint`, and provide a `pretrained_model_name_or_path` for the component if we intend to load the model weights from a specific repository on the Hub. + +```py +# Inside block.py +from diffusers.modular_pipelines import ( + ModularPipelineBlocks, + ComponentSpec, +) +from transformers import AutoProcessor, Florence2ForConditionalGeneration + + +class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): + + @property + def expected_components(self): + return [ + ComponentSpec( + name="image_annotator", + type_hint=Florence2ForConditionalGeneration, + pretrained_model_name_or_path="florence-community/Florence-2-base-ft", + ), + ComponentSpec( + name="image_annotator_processor", + type_hint=AutoProcessor, + pretrained_model_name_or_path="florence-community/Florence-2-base-ft", + ), + ] +``` + +Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations. 
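+
+Before reading the full set of declarations used by the Florence 2 block (shown in the next code block), the short sketch below shows how a single `InputParam` and `OutputParam` are declared. It simply reuses two declarations from the example that follows; the variable names and comments are ours, and the comments describe how the fields are used in this guide rather than serving as an exhaustive API reference.
+
+```py
+from typing import List, Union
+
+from PIL import Image
+
+from diffusers.modular_pipelines import InputParam, OutputParam
+
+# The first positional argument is the name under which the value is passed to and
+# retrieved from the pipeline state. `type_hint` documents the expected type,
+# `required`/`default` control whether the caller must supply the value and what it
+# falls back to, and `description` provides human-readable documentation.
+example_input = InputParam(
+    "image",
+    type_hint=Union[Image.Image, List[Image.Image]],
+    required=True,
+    description="Image(s) to annotate",
+)
+example_output = OutputParam(
+    "mask_image",
+    type_hint=Image,
+    description="Inpainting Mask for input Image(s)",
+)
+```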
+ +```py +from typing import List, Union +from PIL import Image, ImageDraw +import torch +import numpy as np + +from diffusers.modular_pipelines import ( + PipelineState, + ModularPipelineBlocks, + InputParam, + ComponentSpec, + OutputParam, +) +from transformers import AutoProcessor, Florence2ForConditionalGeneration + + +class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): + + @property + def expected_components(self): + return [ + ComponentSpec( + name="image_annotator", + type_hint=Florence2ForConditionalGeneration, + pretrained_model_name_or_path="florence-community/Florence-2-base-ft", + ), + ComponentSpec( + name="image_annotator_processor", + type_hint=AutoProcessor, + pretrained_model_name_or_path="florence-community/Florence-2-base-ft", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "image", + type_hint=Union[Image.Image, List[Image.Image]], + required=True, + description="Image(s) to annotate", + ), + InputParam( + "annotation_task", + type_hint=Union[str, List[str]], + required=True, + default="", + description="""Annotation Task to perform on the image. + Supported Tasks: + + + + + + + + + + + """, + ), + InputParam( + "annotation_prompt", + type_hint=Union[str, List[str]], + required=True, + description="""Annotation Prompt to provide more context to the task. + Can be used to detect or segment out specific elements in the image + """, + ), + InputParam( + "annotation_output_type", + type_hint=str, + required=True, + default="mask_image", + description="""Output type from annotation predictions. Availabe options are + mask_image: + -black and white mask image for the given image based on the task type + mask_overlay: + - mask overlayed on the original image + bounding_box: + - bounding boxes drawn on the original image + """, + ), + InputParam( + "annotation_overlay", + type_hint=bool, + required=True, + default=False, + description="", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "mask_image", + type_hint=Image, + description="Inpainting Mask for input Image(s)", + ), + OutputParam( + "annotations", + type_hint=dict, + description="Annotations Predictions for input Image(s)", + ), + OutputParam( + "image", + type_hint=Image, + description="Annotated input Image(s)", + ), + ] + +``` + +Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask. 
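+
+The general shape of a block's `__call__` method is always the same: read the block's view of the pipeline state, do the work, write the results back, and return `components` and `state`. The minimal skeleton below illustrates just that pattern; the class name and comments are ours and the actual computation is elided. The complete Florence 2 implementation follows.
+
+```py
+import torch
+
+from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState
+
+
+class SkeletonAnnotatorBlock(ModularPipelineBlocks):
+    # expected_components, inputs and intermediate_outputs are omitted here for brevity
+
+    @torch.no_grad()
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        # 1. Pull the inputs declared by this block out of the shared pipeline state
+        block_state = self.get_block_state(state)
+
+        # 2. Run the block's computation here, e.g. annotate `block_state.image`
+        #    using the models exposed on `components`
+
+        # 3. Write the results back so downstream blocks can consume them
+        self.set_block_state(state, block_state)
+        return components, state
+```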
+ +```py +from typing import List, Union +from PIL import Image, ImageDraw +import torch +import numpy as np + +from diffusers.modular_pipelines import ( + PipelineState, + ModularPipelineBlocks, + InputParam, + ComponentSpec, + OutputParam, +) +from transformers import AutoProcessor, Florence2ForConditionalGeneration + + +class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): + + @property + def expected_components(self): + return [ + ComponentSpec( + name="image_annotator", + type_hint=Florence2ForConditionalGeneration, + pretrained_model_name_or_path="florence-community/Florence-2-base-ft", + ), + ComponentSpec( + name="image_annotator_processor", + type_hint=AutoProcessor, + pretrained_model_name_or_path="florence-community/Florence-2-base-ft", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "image", + type_hint=Union[Image.Image, List[Image.Image]], + required=True, + description="Image(s) to annotate", + ), + InputParam( + "annotation_task", + type_hint=Union[str, List[str]], + required=True, + default="", + description="""Annotation Task to perform on the image. + Supported Tasks: + + + + + + + + + + + """, + ), + InputParam( + "annotation_prompt", + type_hint=Union[str, List[str]], + required=True, + description="""Annotation Prompt to provide more context to the task. + Can be used to detect or segment out specific elements in the image + """, + ), + InputParam( + "annotation_output_type", + type_hint=str, + required=True, + default="mask_image", + description="""Output type from annotation predictions. Availabe options are + mask_image: + -black and white mask image for the given image based on the task type + mask_overlay: + - mask overlayed on the original image + bounding_box: + - bounding boxes drawn on the original image + """, + ), + InputParam( + "annotation_overlay", + type_hint=bool, + required=True, + default=False, + description="", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "mask_image", + type_hint=Image, + description="Inpainting Mask for input Image(s)", + ), + OutputParam( + "annotations", + type_hint=dict, + description="Annotations Predictions for input Image(s)", + ), + OutputParam( + "image", + type_hint=Image, + description="Annotated input Image(s)", + ), + ] + + def get_annotations(self, components, images, prompts, task): + task_prompts = [task + prompt for prompt in prompts] + + inputs = components.image_annotator_processor( + text=task_prompts, images=images, return_tensors="pt" + ).to(components.image_annotator.device, components.image_annotator.dtype) + + generated_ids = components.image_annotator.generate( + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=1024, + early_stopping=False, + do_sample=False, + num_beams=3, + ) + annotations = components.image_annotator_processor.batch_decode( + generated_ids, skip_special_tokens=False + ) + outputs = [] + for image, annotation in zip(images, annotations): + outputs.append( + components.image_annotator_processor.post_process_generation( + annotation, task=task, image_size=(image.width, image.height) + ) + ) + return outputs + + def prepare_mask(self, images, annotations, overlay=False, fill="white"): + masks = [] + for image, annotation in zip(images, annotations): + mask_image = image.copy() if overlay else Image.new("L", image.size, 0) + draw = ImageDraw.Draw(mask_image) + + for _, _annotation in annotation.items(): + if "polygons" in _annotation: + for 
polygon in _annotation["polygons"]: + polygon = np.array(polygon).reshape(-1, 2) + if len(polygon) < 3: + continue + polygon = polygon.reshape(-1).tolist() + draw.polygon(polygon, fill=fill) + + elif "bbox" in _annotation: + bbox = _annotation["bbox"] + draw.rectangle(bbox, fill="white") + + masks.append(mask_image) + + return masks + + def prepare_bounding_boxes(self, images, annotations): + outputs = [] + for image, annotation in zip(images, annotations): + image_copy = image.copy() + draw = ImageDraw.Draw(image_copy) + for _, _annotation in annotation.items(): + bbox = _annotation["bbox"] + label = _annotation["label"] + + draw.rectangle(bbox, outline="red", width=3) + draw.text((bbox[0], bbox[1] - 20), label, fill="red") + + outputs.append(image_copy) + + return outputs + + def prepare_inputs(self, images, prompts): + prompts = prompts or "" + + if isinstance(images, Image.Image): + images = [images] + if isinstance(prompts, str): + prompts = [prompts] + + if len(images) != len(prompts): + raise ValueError("Number of images and annotation prompts must match.") + + return images, prompts + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + images, annotation_task_prompt = self.prepare_inputs( + block_state.image, block_state.annotation_prompt + ) + task = block_state.annotation_task + fill = block_state.fill + + annotations = self.get_annotations( + components, images, annotation_task_prompt, task + ) + block_state.annotations = annotations + if block_state.annotation_output_type == "mask_image": + block_state.mask_image = self.prepare_mask(images, annotations) + else: + block_state.mask_image = None + + if block_state.annotation_output_type == "mask_overlay": + block_state.image = self.prepare_mask(images, annotations, overlay=True, fill=fill) + + elif block_state.annotation_output_type == "bounding_box": + block_state.image = self.prepare_bounding_boxes(images, annotations) + + self.set_block_state(state, block_state) + + return components, state + +``` + +Once we have defined our custom block, we can save it to the Hub, using either the CLI or the [`push_to_hub`] method. This will make it easy to share and reuse our custom block with other pipelines. + + + + +```shell +# In the folder with the `block.py` file, run: +diffusers-cli custom_block +``` + +Then upload the block to the Hub: + +```shell +hf upload . . +``` + + + +```py +from block import Florence2ImageAnnotatorBlock +block = Florence2ImageAnnotatorBlock() +block.push_to_hub("") +``` + + + + +## Using Custom Blocks + +Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`. 
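+
+If you only want to confirm that the block resolves and loads correctly before wiring it into a pipeline, a quick load-and-inspect sketch such as the one below is enough; the properties printed here are the ones defined in `block.py` above. The full example of inserting the block into an inpainting pipeline follows.
+
+```py
+from diffusers.modular_pipelines import ModularPipelineBlocks
+
+# Download the block definition from the Hub and execute it (hence trust_remote_code=True)
+image_annotator_block = ModularPipelineBlocks.from_pretrained(
+    "diffusers/florence-2-custom-block", trust_remote_code=True
+)
+
+# These properties come straight from the Florence2ImageAnnotatorBlock class shown earlier
+print(image_annotator_block.expected_components)
+print(image_annotator_block.inputs)
+print(image_annotator_block.intermediate_outputs)
+```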
+ +```py +import torch +from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS +from diffusers.utils import load_image + +# Fetch the Florence2 image annotator block that will create our mask +image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True) + +my_blocks = INPAINT_BLOCKS.copy() +# insert the annotation block before the image encoding step +my_blocks.insert("image_annotator", image_annotator_block, 1) + +# Create our initial set of inpainting blocks +blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks) + +repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0" +pipe = blocks.init_pipeline(repo_id) +pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True) + +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true") +image = image.resize((1024, 1024)) + +prompt = ["A red car"] +annotation_task = "" +annotation_prompt = ["the car"] + +output = pipe( + prompt=prompt, + image=image, + annotation_task=annotation_task, + annotation_prompt=annotation_prompt, + annotation_output_type="mask_image", + num_inference_steps=35, + guidance_scale=7.5, + strength=0.95, + output="images" +) +output[0].save("florence-inpainting.png") +``` + +## Editing Custom Blocks + +By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder. + +```py +import torch +from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS +from diffusers.utils import load_image + +# Fetch the Florence2 image annotator block that will create our mask +image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder") +``` + +Any changes made to the block files in this folder will be reflected when you load the block again. From 63dd601758b3f520547513e61e88b6cdb42174f2 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Sat, 15 Nov 2025 00:12:24 +0100 Subject: [PATCH 13/35] Improve docstrings and type hints in scheduling_euler_discrete.py (#12654) * refactor: enhance type hints and documentation in EulerDiscreteScheduler Updated type hints for function parameters and return types in the EulerDiscreteScheduler class to improve code clarity and maintainability. Enhanced docstrings for several methods to provide clearer descriptions of their functionality and expected arguments. This includes specifying Literal types for certain parameters and ensuring consistent return type annotations across the class. * refactor: enhance type hints and documentation across multiple schedulers Updated type hints and improved docstrings in various scheduler classes, including CMStochasticIterativeScheduler, CosineDPMSolverMultistepScheduler, and others. This includes specifying parameter types, return types, and providing clearer descriptions of method functionalities. Notable changes include the addition of default values in the begin_index argument and enhanced explanations for noise addition methods. These improvements aim to enhance code clarity and maintainability across the scheduling module. 
* refactor: update docstrings to clarify noise schedule construction Revised docstrings across multiple scheduler classes to enhance clarity regarding the construction of noise schedules. Updated references to relevant papers, ensuring accurate citations for the methodologies used. This includes changes in DEISMultistepScheduler, DPMSolverMultistepInverseScheduler, and others, improving documentation consistency and readability. --- .../scheduling_consistency_models.py | 44 ++- .../scheduling_cosine_dpmsolver_multistep.py | 30 ++- src/diffusers/schedulers/scheduling_ddim.py | 5 +- .../schedulers/scheduling_ddim_inverse.py | 5 +- .../schedulers/scheduling_ddim_parallel.py | 5 +- src/diffusers/schedulers/scheduling_ddpm.py | 5 +- .../schedulers/scheduling_ddpm_parallel.py | 5 +- .../schedulers/scheduling_deis_multistep.py | 63 ++++- .../scheduling_dpmsolver_multistep.py | 68 ++++- .../scheduling_dpmsolver_multistep_inverse.py | 61 ++++- .../schedulers/scheduling_dpmsolver_sde.py | 90 ++++++- .../scheduling_dpmsolver_singlestep.py | 63 ++++- .../scheduling_edm_dpmsolver_multistep.py | 30 ++- .../schedulers/scheduling_edm_euler.py | 44 ++- .../scheduling_euler_ancestral_discrete.py | 49 +++- .../schedulers/scheduling_euler_discrete.py | 254 ++++++++++++++---- .../scheduling_flow_match_euler_discrete.py | 50 +++- .../scheduling_flow_match_heun_discrete.py | 2 +- .../schedulers/scheduling_flow_match_lcm.py | 50 +++- .../schedulers/scheduling_heun_discrete.py | 105 +++++++- src/diffusers/schedulers/scheduling_ipndm.py | 29 +- .../scheduling_k_dpm_2_ancestral_discrete.py | 105 +++++++- .../schedulers/scheduling_k_dpm_2_discrete.py | 105 +++++++- src/diffusers/schedulers/scheduling_lcm.py | 34 ++- .../schedulers/scheduling_lms_discrete.py | 90 ++++++- .../schedulers/scheduling_sasolver.py | 63 ++++- src/diffusers/schedulers/scheduling_scm.py | 29 +- src/diffusers/schedulers/scheduling_tcd.py | 34 ++- .../schedulers/scheduling_unipc_multistep.py | 68 ++++- 29 files changed, 1426 insertions(+), 159 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 5d81d5eb8ac0..386a43db0f9c 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -121,7 +121,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -287,7 +287,23 @@ def get_scalings_for_boundary_condition(self, sigma): return c_skip, c_out # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index of a given timestep in the timestep schedule. + + Args: + timestep (`float` or `torch.Tensor`): + The timestep value to find in the schedule. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. 
For the very first step, returns the second index if + multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image). + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -302,7 +318,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): return indices[pos].item() # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: + """ + Initialize the step index for the scheduler based on the given timestep. + + Args: + timestep (`float` or `torch.Tensor`): + The current timestep to initialize the step index from. + """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -410,6 +433,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise tensor to add to the original samples. + timesteps (`torch.Tensor`): + The timesteps at which to add noise, determining the noise level from the schedule. + + Returns: + `torch.Tensor`: + The noisy samples with added noise scaled according to the timestep schedule. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py index b9567f2c47d5..7b11d704932b 100644 --- a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -137,7 +137,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -266,6 +266,19 @@ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. + + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). + """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -537,6 +550,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise tensor to add to the original samples. + timesteps (`torch.Tensor`): + The timesteps at which to add noise, determining the noise level from the schedule. 
+ + Returns: + `torch.Tensor`: + The noisy samples with added noise scaled according to the timestep schedule. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 5ddc46ee4d2f..d7fe29a72ac9 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -99,10 +99,11 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: Args: betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + The betas that the scheduler is being initialized with. Returns: - `torch.Tensor`: rescaled betas with zero terminal SNR + `torch.Tensor`: + Rescaled betas with zero terminal SNR. """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index bed424e320be..a7717940e2a1 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -98,10 +98,11 @@ def rescale_zero_terminal_snr(betas): Args: betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + The betas that the scheduler is being initialized with. Returns: - `torch.Tensor`: rescaled betas with zero terminal SNR + `torch.Tensor`: + Rescaled betas with zero terminal SNR. """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py index 1432d835aea2..d957ade901b3 100644 --- a/src/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -100,10 +100,11 @@ def rescale_zero_terminal_snr(betas): Args: betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + The betas that the scheduler is being initialized with. Returns: - `torch.Tensor`: rescaled betas with zero terminal SNR + `torch.Tensor`: + Rescaled betas with zero terminal SNR. """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 5ccf4adaebbc..1d0ad49c58cd 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -97,10 +97,11 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: Args: betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + The betas that the scheduler is being initialized with. Returns: - `torch.Tensor`: rescaled betas with zero terminal SNR + `torch.Tensor`: + Rescaled betas with zero terminal SNR. """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index 8740f14c66b4..78011d0e46a1 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -99,10 +99,11 @@ def rescale_zero_terminal_snr(betas): Args: betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + The betas that the scheduler is being initialized with. 
Returns: - `torch.Tensor`: rescaled betas with zero terminal SNR + `torch.Tensor`: + Rescaled betas with zero terminal SNR. """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 15d8a20e33b8..bf8e1d98d6c0 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -230,7 +230,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -364,6 +364,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. + + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). + """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -399,7 +412,20 @@ def _sigma_to_alpha_sigma_t(self, sigma): # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -425,7 +451,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -449,7 +487,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. 
al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index b1f218a5eb06..dee97f39ff68 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -83,10 +83,11 @@ def rescale_zero_terminal_snr(betas): Args: betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + The betas that the scheduler is being initialized with. Returns: - `torch.Tensor`: rescaled betas with zero terminal SNR + `torch.Tensor`: + Rescaled betas with zero terminal SNR. """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas @@ -323,7 +324,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -503,6 +504,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. + + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). + """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -537,7 +551,20 @@ def _sigma_to_alpha_sigma_t(self, sigma): # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. 
+ """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -576,7 +603,19 @@ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -600,7 +639,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 476d2fc10568..0f734aeb54c9 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -376,6 +376,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. + + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). + """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -411,7 +424,20 @@ def _sigma_to_alpha_sigma_t(self, sigma): # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). 
+ + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -437,7 +463,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -461,7 +499,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 2b02c2fd5e57..e22954d4e6ea 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -251,7 +251,23 @@ def __init__( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index of a given timestep in the timestep schedule. + + Args: + timestep (`float` or `torch.Tensor`): + The timestep value to find in the schedule. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. For the very first step, returns the second index if + multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image). 
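To make the lookup rule spelled out in this `Returns` section concrete, here is a small standalone sketch. The helper name and the toy schedule are illustrative, and the "second match" preference is applied unconditionally here rather than only on the very first step:

```python
import torch

def lookup_timestep_index(timestep: float, schedule_timesteps: torch.Tensor) -> int:
    # All positions whose scheduled timestep equals the requested one.
    indices = (schedule_timesteps == timestep).nonzero()
    # Prefer the second match so a run that starts mid-schedule
    # (e.g. image-to-image) does not skip a sigma.
    pos = 1 if len(indices) > 1 else 0
    return indices[pos].item()

# Duplicated timesteps can occur when sigmas are interpolated.
schedule = torch.tensor([999.0, 749.0, 749.0, 499.0, 249.0])
print(lookup_timestep_index(749.0, schedule))  # prints 2, not 1
```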
+ """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -266,7 +282,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): return indices[pos].item() # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: + """ + Initialize the step index for the scheduler based on the given timestep. + + Args: + timestep (`float` or `torch.Tensor`): + The current timestep to initialize the step index from. + """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -302,7 +325,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -430,6 +453,19 @@ def t_fn(_sigma): # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. + + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). + """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -468,7 +504,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor: # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -492,7 +540,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. 
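The beta schedule documented above maps an evenly spaced ramp through the inverse CDF of a Beta(alpha, beta) distribution. A minimal sketch of that idea follows; the `beta_sigmas` helper and the 0.03/14.6 endpoints are illustrative, and `scipy` is required, as the schedulers themselves note:

```python
import numpy as np
import scipy.stats

def beta_sigmas(sigma_min, sigma_max, num_inference_steps, alpha=0.6, beta=0.6):
    # Descending ramp in [0, 1] pushed through the Beta inverse CDF (ppf),
    # then rescaled into [sigma_min, sigma_max].
    ramp = 1 - np.linspace(0, 1, num_inference_steps)
    return sigma_min + scipy.stats.beta.ppf(ramp, alpha, beta) * (sigma_max - sigma_min)

print(beta_sigmas(0.03, 14.6, 8).round(3))  # spacing is denser near both sigma endpoints
```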
+ """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -646,6 +711,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise tensor to add to the original samples. + timesteps (`torch.Tensor`): + The timesteps at which to add noise, determining the noise level from the schedule. + + Returns: + `torch.Tensor`: + The noisy samples with added noise scaled according to the timestep schedule. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index e7fde2c2ba0d..0b271d7eacb4 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -295,7 +295,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -454,6 +454,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. + + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). + """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -489,7 +502,20 @@ def _sigma_to_alpha_sigma_t(self, sigma): # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. 
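For reference, the Karras et al. (2022) schedule cited above interpolates linearly in sigma**(1/rho) space with rho = 7. A self-contained sketch of that formulation; the helper name and the sigma endpoints are illustrative:

```python
import numpy as np

def karras_sigmas(sigma_min, sigma_max, num_inference_steps, rho=7.0):
    # Linear ramp in sigma**(1/rho) space, raised back to the rho-th power.
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(karras_sigmas(0.03, 14.6, 10).round(3))  # spacing tightens toward sigma_min
```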
+ """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -515,7 +541,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -539,7 +577,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py index f748c6c834a3..eeec588e27a3 100644 --- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py @@ -169,7 +169,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -342,6 +342,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. + + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). + """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -682,6 +695,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. 
+ + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise tensor to add to the original samples. + timesteps (`torch.Tensor`): + The timesteps at which to add noise, determining the noise level from the schedule. + + Returns: + `torch.Tensor`: + The noisy samples with added noise scaled according to the timestep schedule. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py index dbeff3de5652..0bf17356a7fa 100644 --- a/src/diffusers/schedulers/scheduling_edm_euler.py +++ b/src/diffusers/schedulers/scheduling_edm_euler.py @@ -155,7 +155,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -284,7 +284,23 @@ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> t return sigmas # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index of a given timestep in the timestep schedule. + + Args: + timestep (`float` or `torch.Tensor`): + The timestep value to find in the schedule. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. For the very first step, returns the second index if + multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image). + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -299,7 +315,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): return indices[pos].item() # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: + """ + Initialize the step index for the scheduler based on the given timestep. + + Args: + timestep (`float` or `torch.Tensor`): + The current timestep to initialize the step index from. + """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -413,6 +436,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise tensor to add to the original samples. + timesteps (`torch.Tensor`): + The timesteps at which to add noise, determining the noise level from the schedule. + + Returns: + `torch.Tensor`: + The noisy samples with added noise scaled according to the timestep schedule. 
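For the EDM-style schedulers covered in these hunks, the `add_noise` behaviour described above boils down to x_t = x_0 + sigma_t * eps, with sigma broadcast over the sample shape. A rough sketch, assuming the per-sample step indices have already been resolved (for example via `index_for_timestep`); the helper name and shapes are illustrative:

```python
import torch

def add_noise_sketch(original_samples, noise, sigmas, step_indices):
    # Pick one sigma per sample and broadcast it across the trailing dims.
    sigma = sigmas[step_indices].flatten()
    while sigma.ndim < original_samples.ndim:
        sigma = sigma.unsqueeze(-1)
    return original_samples + noise * sigma

x0 = torch.randn(2, 4, 8, 8)
eps = torch.randn_like(x0)
sigmas = torch.tensor([14.6, 7.0, 1.0, 0.03])
noisy = add_noise_sketch(x0, eps, sigmas, step_indices=torch.tensor([0, 2]))
```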
+ """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index b2741c586be2..8f39507301ce 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -100,10 +100,11 @@ def rescale_zero_terminal_snr(betas): Args: betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + The betas that the scheduler is being initialized with. Returns: - `torch.Tensor`: rescaled betas with zero terminal SNR + `torch.Tensor`: + Rescaled betas with zero terminal SNR. """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas @@ -245,7 +246,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -319,7 +320,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index of a given timestep in the timestep schedule. + + Args: + timestep (`float` or `torch.Tensor`): + The timestep value to find in the schedule. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. For the very first step, returns the second index if + multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image). + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -334,7 +351,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): return indices[pos].item() # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: + """ + Initialize the step index for the scheduler based on the given timestep. + + Args: + timestep (`float` or `torch.Tensor`): + The current timestep to initialize the step index from. + """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -451,6 +475,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise tensor to add to the original samples. + timesteps (`torch.Tensor`): + The timesteps at which to add noise, determining the noise level from the schedule. 
+ + Returns: + `torch.Tensor`: + The noisy samples with added noise scaled according to the timestep schedule. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index f88b124f04e7..5ea926c4ca38 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -97,16 +97,17 @@ def alpha_bar_fn(t): # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr -def rescale_zero_terminal_snr(betas): +def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) Args: betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + The betas that the scheduler is being initialized with. Returns: - `torch.Tensor`: rescaled betas with zero terminal SNR + `torch.Tensor`: + Rescaled betas with zero terminal SNR. """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas @@ -146,17 +147,17 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): + beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. + `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. - prediction_type (`str`, defaults to `epsilon`, *optional*): - Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`, *optional*): + Prediction type of the scheduler function; can be `"epsilon"` (predicts the noise of the diffusion + process), `"sample"` (directly predicts the noisy sample`) or `"v_prediction"` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). - interpolation_type(`str`, defaults to `"linear"`, *optional*): - The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be on of + interpolation_type (`Literal["linear", "log_linear"]`, defaults to `"linear"`, *optional*): + The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of `"linear"` or `"log_linear"`. use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, @@ -166,18 +167,26 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): use_beta_sigmas (`bool`, *optional*, defaults to `False`): Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. 
- timestep_spacing (`str`, defaults to `"linspace"`): + sigma_min (`float`, *optional*): + The minimum sigma value for the noise schedule. If not provided, defaults to the last sigma in the + schedule. + sigma_max (`float`, *optional*): + The maximum sigma value for the noise schedule. If not provided, defaults to the first sigma in the + schedule. + timestep_spacing (`Literal["linspace", "leading", "trailing"]`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + timestep_type (`Literal["discrete", "continuous"]`, defaults to `"discrete"`): + The type of timesteps to use. Can be `"discrete"` or `"continuous"`. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). - final_sigmas_type (`str`, defaults to `"zero"`): + final_sigmas_type (`Literal["zero", "sigma_min"]`, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final - sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -189,20 +198,20 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - prediction_type: str = "epsilon", - interpolation_type: str = "linear", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", + interpolation_type: Literal["linear", "log_linear"] = "linear", use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False, sigma_min: Optional[float] = None, sigma_max: Optional[float] = None, - timestep_spacing: str = "linspace", - timestep_type: str = "discrete", # can be "discrete" or "continuous" + timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace", + timestep_type: Literal["discrete", "continuous"] = "discrete", steps_offset: int = 0, rescale_betas_zero_snr: bool = False, - final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" + final_sigmas_type: Literal["zero", "sigma_min"] = "zero", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -259,8 +268,15 @@ def __init__( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property - def init_noise_sigma(self): - # standard deviation of the initial noise distribution + def init_noise_sigma(self) -> Union[float, torch.Tensor]: + """ + The standard deviation of the initial noise distribution. 
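The `init_noise_sigma` property summarized above depends on the configured timestep spacing. A rough sketch of the rule, together with the way pipelines typically consume the value; the standalone function and the latent shape are illustrative:

```python
import torch

def init_noise_sigma(sigmas: torch.Tensor, timestep_spacing: str) -> float:
    # "linspace"/"trailing" schedules start directly at the largest sigma;
    # other spacings use sqrt(sigma_max**2 + 1).
    max_sigma = sigmas.max().item()
    if timestep_spacing in ("linspace", "trailing"):
        return max_sigma
    return (max_sigma**2 + 1) ** 0.5

# Pipelines usually scale their freshly sampled latents by this value, and
# `scale_model_input` later divides by sqrt(sigma**2 + 1) at each step.
latents = torch.randn(1, 4, 64, 64)
latents = latents * init_noise_sigma(torch.tensor([14.6, 7.0, 0.03]), "linspace")
```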
+ + Returns: + `float` or `torch.Tensor`: + The standard deviation of the initial noise distribution, computed based on the maximum sigma value and + the timestep spacing configuration. + """ max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max() if self.config.timestep_spacing in ["linspace", "trailing"]: return max_sigma @@ -268,26 +284,34 @@ def init_noise_sigma(self): return (max_sigma**2 + 1) ** 0.5 @property - def step_index(self): + def step_index(self) -> Optional[int]: """ - The index counter for current timestep. It will increase 1 after each scheduler step. + The index counter for current timestep. It will increase by 1 after each scheduler step. + + Returns: + `int` or `None`: + The current step index, or `None` if not initialized. """ return self._step_index @property - def begin_index(self): + def begin_index(self) -> Optional[int]: """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + + Returns: + `int` or `None`: + The begin index for the scheduler, or `None` if not set. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -299,13 +323,13 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T Args: sample (`torch.Tensor`): - The input sample. - timestep (`int`, *optional*): + The input sample to be scaled. + timestep (`float` or `torch.Tensor`): The current timestep in the diffusion chain. Returns: `torch.Tensor`: - A scaled input sample. + A scaled input sample, divided by `(sigma**2 + 1) ** 0.5`. """ if self.step_index is None: self._init_step_index(timestep) @@ -318,17 +342,18 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T def set_timesteps( self, - num_inference_steps: int = None, - device: Union[str, torch.device] = None, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, - ): + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. If `None`, + `timesteps` or `sigmas` must be provided. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): @@ -336,10 +361,9 @@ def set_timesteps( based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`, and `timestep_spacing` attribute will be ignored. sigmas (`List[float]`, *optional*): - Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas - will be generated based on the relevant scheduler attributes. 
If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the - custom sigmas schedule. + Custom sigmas used to support arbitrary timesteps schedule. If `None`, timesteps and sigmas will be + generated based on the relevant scheduler attributes. If `sigmas` is passed, `num_inference_steps` and + `timesteps` must be `None`, and the timesteps will be generated based on the custom sigmas schedule. """ if timesteps is not None and sigmas is not None: @@ -449,7 +473,20 @@ def set_timesteps( self._begin_index = None self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication - def _sigma_to_t(self, sigma, log_sigmas): + def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. + + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). + """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -473,8 +510,21 @@ def _sigma_to_t(self, sigma, log_sigmas): return t # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 - def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -500,7 +550,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26 def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -523,7 +585,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). 
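The fully typed `_sigma_to_t` added above interpolates a requested sigma against the log-sigma table to obtain a fractional training-timestep index. A simplified equivalent using `np.interp` is sketched below; the scheduler's own code performs the same piecewise-linear interpolation by hand, and the toy schedule here is illustrative:

```python
import numpy as np

def sigma_to_t(sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
    # `log_sigmas[i]` is log(sigma) at training timestep i (ascending), so
    # interpolating log(sigma) against it yields a fractional timestep index.
    log_sigma = np.log(np.maximum(sigma, 1e-10))
    return np.interp(log_sigma, log_sigmas, np.arange(len(log_sigmas)))

log_sigmas = np.log(np.linspace(0.03, 14.6, 1000))  # toy ascending schedule
print(sigma_to_t(np.array([0.5, 5.0]), log_sigmas))
```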
+ + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -551,7 +630,23 @@ def _convert_to_beta( ) return sigmas - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index of a given timestep in the timestep schedule. + + Args: + timestep (`float` or `torch.Tensor`): + The timestep value to find in the schedule. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. For the very first step, returns the second index if + multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image). + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -565,7 +660,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): return indices[pos].item() - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: + """ + Initialize the step index for the scheduler based on the given timestep. + + Args: + timestep (`float` or `torch.Tensor`): + The current timestep to initialize the step index from. + """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -591,26 +693,33 @@ def step( Args: model_output (`torch.Tensor`): - The direct output from learned diffusion model. - timestep (`float`): + The direct output from the learned diffusion model. + timestep (`float` or `torch.Tensor`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - s_churn (`float`): - s_tmin (`float`): - s_tmax (`float`): - s_noise (`float`, defaults to 1.0): + s_churn (`float`, *optional*, defaults to `0.0`): + Stochasticity parameter that controls the amount of noise added during sampling. Higher values increase + randomness. + s_tmin (`float`, *optional*, defaults to `0.0`): + Minimum timestep threshold for applying stochasticity. Only timesteps above this value will have noise + added. + s_tmax (`float`, *optional*, defaults to `inf`): + Maximum timestep threshold for applying stochasticity. Only timesteps below this value will have noise + added. + s_noise (`float`, *optional*, defaults to `1.0`): Scaling factor for noise added to the sample. generator (`torch.Generator`, *optional*): - A random number generator. - return_dict (`bool`): + A random number generator for reproducible sampling. + return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or tuple. 
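To give the `s_churn`/`s_tmin`/`s_tmax`/`s_noise` arguments documented above some context: before the Euler update, the scheduler may temporarily inflate the current sigma and inject matching noise. Roughly, something like the sketch below happens first; note that the thresholds are compared against the current sigma value, and `churn_step` is an illustrative name rather than part of this diff:

```python
from typing import Optional

import torch

def churn_step(sample, sigma, num_sigmas, s_churn=0.0, s_tmin=0.0,
               s_tmax=float("inf"), s_noise=1.0,
               generator: Optional[torch.Generator] = None):
    # Inside the [s_tmin, s_tmax] sigma band, inflate sigma by (1 + gamma) and
    # add just enough noise to keep the sample consistent with sigma_hat.
    gamma = min(s_churn / (num_sigmas - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
    sigma_hat = sigma * (gamma + 1)
    if gamma > 0:
        eps = torch.randn(sample.shape, generator=generator) * s_noise
        sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
    return sample, sigma_hat

sample, sigma_hat = churn_step(torch.randn(1, 4, 8, 8), sigma=7.0, num_sigmas=30, s_churn=0.5)
```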
Returns: [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is - returned, otherwise a tuple is returned where the first element is the sample tensor. + If `return_dict` is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor and the second + element is the predicted original sample. """ if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): @@ -689,6 +798,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise tensor to add to the original samples. + timesteps (`torch.Tensor`): + The timesteps at which to add noise, determining the noise level from the schedule. + + Returns: + `torch.Tensor`: + The noisy samples with added noise scaled according to the timestep schedule. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): @@ -717,6 +841,24 @@ def add_noise( return noisy_samples def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """ + Compute the velocity prediction for the given sample and noise at the specified timesteps. + + This method implements the velocity prediction used in v-prediction models, which predicts a linear combination + of the sample and noise. + + Args: + sample (`torch.Tensor`): + The input sample for which to compute the velocity. + noise (`torch.Tensor`): + The noise tensor corresponding to the sample. + timesteps (`torch.Tensor`): + The timesteps at which to compute the velocity. + + Returns: + `torch.Tensor`: + The velocity prediction computed as `sqrt(alpha_prod) * noise - sqrt(1 - alpha_prod) * sample`. + """ if ( isinstance(timesteps, int) or isinstance(timesteps, torch.IntTensor) @@ -753,5 +895,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 1a4f12ddfa53..9fd61d9e18d1 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -160,7 +160,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -473,7 +473,20 @@ def step( # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. 
(2022).""" + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -499,7 +512,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -523,7 +548,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers diff --git a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py index 38e5f1ba77a8..6febee444c5a 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py @@ -102,7 +102,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index diff --git a/src/diffusers/schedulers/scheduling_flow_match_lcm.py b/src/diffusers/schedulers/scheduling_flow_match_lcm.py index 933bb1cf8e3d..25186d1fe969 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_lcm.py +++ b/src/diffusers/schedulers/scheduling_flow_match_lcm.py @@ -168,7 +168,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. 
This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -473,7 +473,20 @@ def step( # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -499,7 +512,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -523,7 +548,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. 
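Among the docstrings added above for the flow-match LCM scheduler is `_convert_to_exponential`; the schedule it describes is simply a geometric progression between the sigma endpoints, sketched below (the helper name and endpoints are illustrative):

```python
import math

import numpy as np

def exponential_sigmas(sigma_min, sigma_max, num_inference_steps):
    # Evenly spaced in log space, i.e. a geometric progression from sigma_max
    # down to sigma_min.
    return np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))

print(exponential_sigmas(0.03, 14.6, 6).round(3))
```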
+ """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index db81fc82bcf3..930b0344646d 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -188,7 +188,23 @@ def __init__( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index of a given timestep in the timestep schedule. + + Args: + timestep (`float` or `torch.Tensor`): + The timestep value to find in the schedule. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. For the very first step, returns the second index if + multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image). + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -230,7 +246,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -355,6 +371,19 @@ def set_timesteps( # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. + + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). + """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -379,7 +408,20 @@ def _sigma_to_t(self, sigma, log_sigmas): # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. 
+ """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -405,7 +447,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -429,7 +483,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -462,7 +533,14 @@ def state_in_first_order(self): return self.dt is None # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: + """ + Initialize the step index for the scheduler based on the given timestep. + + Args: + timestep (`float` or `torch.Tensor`): + The current timestep to initialize the step index from. + """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -580,6 +658,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise tensor to add to the original samples. + timesteps (`torch.Tensor`): + The timesteps at which to add noise, determining the noise level from the schedule. + + Returns: + `torch.Tensor`: + The noisy samples with added noise scaled according to the timestep schedule. 
+ """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_ipndm.py b/src/diffusers/schedulers/scheduling_ipndm.py index 23bc21f10ca4..da188fe8297c 100644 --- a/src/diffusers/schedulers/scheduling_ipndm.py +++ b/src/diffusers/schedulers/scheduling_ipndm.py @@ -78,7 +78,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -112,7 +112,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self._begin_index = None # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index of a given timestep in the timestep schedule. + + Args: + timestep (`float` or `torch.Tensor`): + The timestep value to find in the schedule. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. For the very first step, returns the second index if + multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image). + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -127,7 +143,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): return indices[pos].item() # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: + """ + Initialize the step index for the scheduler based on the given timestep. + + Args: + timestep (`float` or `torch.Tensor`): + The current timestep to initialize the step index from. + """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 48cc01e6aac7..595b93c39d4c 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -207,7 +207,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -343,6 +343,19 @@ def set_timesteps( # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. 
+ + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). + """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -367,7 +380,20 @@ def _sigma_to_t(self, sigma, log_sigmas): # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -393,7 +419,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -417,7 +455,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -450,7 +505,23 @@ def state_in_first_order(self): return self.sample is None # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index of a given timestep in the timestep schedule. + + Args: + timestep (`float` or `torch.Tensor`): + The timestep value to find in the schedule. 
+ schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. For the very first step, returns the second index if + multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image). + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -465,7 +536,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): return indices[pos].item() # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: + """ + Initialize the step index for the scheduler based on the given timestep. + + Args: + timestep (`float` or `torch.Tensor`): + The current timestep to initialize the step index from. + """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -587,6 +665,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise tensor to add to the original samples. + timesteps (`torch.Tensor`): + The timesteps at which to add noise, determining the noise level from the schedule. + + Returns: + `torch.Tensor`: + The noisy samples with added noise scaled according to the timestep schedule. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index aaf6a48b57be..7db12227229e 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -207,7 +207,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -331,7 +331,23 @@ def state_in_first_order(self): return self.sample is None # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index of a given timestep in the timestep schedule. + + Args: + timestep (`float` or `torch.Tensor`): + The timestep value to find in the schedule. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. For the very first step, returns the second index if + multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image). 
+ """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -346,7 +362,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): return indices[pos].item() # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: + """ + Initialize the step index for the scheduler based on the given timestep. + + Args: + timestep (`float` or `torch.Tensor`): + The current timestep to initialize the step index from. + """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -356,6 +379,19 @@ def _init_step_index(self, timestep): # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. + + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). + """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -380,7 +416,20 @@ def _sigma_to_t(self, sigma, log_sigmas): # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -406,7 +455,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -430,7 +491,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). 
+ + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -559,6 +637,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise tensor to add to the original samples. + timesteps (`torch.Tensor`): + The timesteps at which to add noise, determining the noise level from the schedule. + + Returns: + `torch.Tensor`: + The noisy samples with added noise scaled according to the timestep schedule. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py index 36587537ec1b..a7b0644de4f5 100644 --- a/src/diffusers/schedulers/scheduling_lcm.py +++ b/src/diffusers/schedulers/scheduling_lcm.py @@ -102,10 +102,11 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: Args: betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + The betas that the scheduler is being initialized with. Returns: - `torch.Tensor`: rescaled betas with zero terminal SNR + `torch.Tensor`: + Rescaled betas with zero terminal SNR. """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas @@ -251,7 +252,23 @@ def __init__( self._begin_index = None # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index of a given timestep in the timestep schedule. + + Args: + timestep (`float` or `torch.Tensor`): + The timestep value to find in the schedule. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. For the very first step, returns the second index if + multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image). + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -266,7 +283,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): return indices[pos].item() # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: + """ + Initialize the step index for the scheduler based on the given timestep. 
+ + Args: + timestep (`float` or `torch.Tensor`): + The current timestep to initialize the step index from. + """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -291,7 +315,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 6fa9c2f7fbcf..573678b100ba 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -210,7 +210,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -320,7 +320,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.derivatives = [] # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index of a given timestep in the timestep schedule. + + Args: + timestep (`float` or `torch.Tensor`): + The timestep value to find in the schedule. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. For the very first step, returns the second index if + multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image). + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -335,7 +351,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): return indices[pos].item() # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: + """ + Initialize the step index for the scheduler based on the given timestep. + + Args: + timestep (`float` or `torch.Tensor`): + The current timestep to initialize the step index from. + """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -345,6 +368,19 @@ def _init_step_index(self, timestep): # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. + + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). 
+ """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -383,7 +419,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor: # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -407,7 +455,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -522,6 +587,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples to which noise will be added. + noise (`torch.Tensor`): + The noise tensor to add to the original samples. + timesteps (`torch.Tensor`): + The timesteps at which to add noise, determining the noise level from the schedule. + + Returns: + `torch.Tensor`: + The noisy samples with added noise scaled according to the timestep schedule. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index 30a3eb294a04..d9054c39c9de 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -254,7 +254,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. 
""" self._begin_index = begin_index @@ -386,6 +386,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. + + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). + """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -421,7 +434,20 @@ def _sigma_to_alpha_sigma_t(self, sigma): # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -447,7 +473,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -471,7 +509,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. 
+ """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py index 63b4a109ff9b..7b01d886299c 100644 --- a/src/diffusers/schedulers/scheduling_scm.py +++ b/src/diffusers/schedulers/scheduling_scm.py @@ -109,7 +109,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -173,7 +173,14 @@ def set_timesteps( self._begin_index = None # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: + """ + Initialize the step index for the scheduler based on the given timestep. + + Args: + timestep (`float` or `torch.Tensor`): + The current timestep to initialize the step index from. + """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -182,7 +189,23 @@ def _init_step_index(self, timestep): self._step_index = self._begin_index # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index of a given timestep in the timestep schedule. + + Args: + timestep (`float` or `torch.Tensor`): + The timestep value to find in the schedule. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. For the very first step, returns the second index if + multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image). + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py index 101b1569a145..37b41c87f8a2 100644 --- a/src/diffusers/schedulers/scheduling_tcd.py +++ b/src/diffusers/schedulers/scheduling_tcd.py @@ -101,10 +101,11 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: Args: betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + The betas that the scheduler is being initialized with. Returns: - `torch.Tensor`: rescaled betas with zero terminal SNR + `torch.Tensor`: + Rescaled betas with zero terminal SNR. """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas @@ -252,7 +253,23 @@ def __init__( self._begin_index = None # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index of a given timestep in the timestep schedule. + + Args: + timestep (`float` or `torch.Tensor`): + The timestep value to find in the schedule. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. 
If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. For the very first step, returns the second index if + multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image). + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -267,7 +284,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): return indices[pos].item() # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: + """ + Initialize the step index for the scheduler based on the given timestep. + + Args: + timestep (`float` or `torch.Tensor`): + The current timestep to initialize the step index from. + """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -292,7 +316,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index d985871df109..7dc5f467680b 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -83,10 +83,11 @@ def rescale_zero_terminal_snr(betas): Args: betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + The betas that the scheduler is being initialized with. Returns: - `torch.Tensor`: rescaled betas with zero terminal SNR + `torch.Tensor`: + Rescaled betas with zero terminal SNR. """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas @@ -297,7 +298,7 @@ def set_begin_index(self, begin_index: int = 0): Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: - begin_index (`int`): + begin_index (`int`, defaults to `0`): The begin index for the scheduler. """ self._begin_index = begin_index @@ -475,6 +476,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): + """ + Convert sigma values to corresponding timestep values through interpolation. + + Args: + sigma (`np.ndarray`): + The sigma value(s) to convert to timestep(s). + log_sigmas (`np.ndarray`): + The logarithm of the sigma schedule used for interpolation. + + Returns: + `np.ndarray`: + The interpolated timestep value(s) corresponding to the input sigma(s). + """ # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -510,7 +524,20 @@ def _sigma_to_alpha_sigma_t(self, sigma): # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. 
+ num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -536,7 +563,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" + """ + Construct an exponential noise schedule. + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted sigma values following an exponential schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers @@ -560,7 +599,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + """ + Construct a beta noise schedule as proposed in [Beta Sampling is All You + Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `torch.Tensor`: + The converted sigma values following a beta distribution schedule. + """ # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers From a9e4883b6a937bfe9360663d198302ec0f0a204c Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Fri, 14 Nov 2025 16:06:22 -0800 Subject: [PATCH 14/35] Update Wan Animate Docs (#12658) * Update the Wan Animate docs to reflect the most recent code * Further explain input preprocessing and link to original Wan Animate preprocessing scripts --- .../api/models/wan_animate_transformer_3d.md | 2 +- docs/source/en/api/pipelines/wan.md | 48 +++++++------------ 2 files changed, 19 insertions(+), 31 deletions(-) diff --git a/docs/source/en/api/models/wan_animate_transformer_3d.md b/docs/source/en/api/models/wan_animate_transformer_3d.md index 798afc72fb8e..cc7b3f0c408c 100644 --- a/docs/source/en/api/models/wan_animate_transformer_3d.md +++ b/docs/source/en/api/models/wan_animate_transformer_3d.md @@ -18,7 +18,7 @@ The model can be loaded with the following code snippet. 
```python from diffusers import WanAnimateTransformer3DModel -transformer = WanAnimateTransformer3DModel.from_pretrained("Wan-AI/Wan2.2-Animate-14B-720P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) +transformer = WanAnimateTransformer3DModel.from_pretrained("Wan-AI/Wan2.2-Animate-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) ``` ## WanAnimateTransformer3DModel diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 3993e2efd0c8..6aab6c5b33b9 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -281,7 +281,7 @@ For replacement mode, you additionally need: - **Mask video**: A mask indicating where to generate content (white) vs. preserve original (black) > [!NOTE] -> The preprocessing tools are available in the original Wan-Animate repository. Integration of these preprocessing steps into Diffusers is planned for a future release. +> Raw videos should not be used for inputs such as `pose_video`, which the pipeline expects to be preprocessed to extract the proper information. Preprocessing scripts to prepare these inputs are available in the [original Wan-Animate repository](https://github.com/Wan-Video/Wan2.2?tab=readme-ov-file#1-preprocessing). Integration of these preprocessing steps into Diffusers is planned for a future release. The example below demonstrates how to use the Wan-Animate pipeline: @@ -293,13 +293,10 @@ import numpy as np import torch from diffusers import AutoencoderKLWan, WanAnimatePipeline from diffusers.utils import export_to_video, load_image, load_video -from transformers import CLIPVisionModel model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers" vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) -pipe = WanAnimatePipeline.from_pretrained( - model_id, vae=vae, torch_dtype=torch.bfloat16 -) +pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) pipe.to("cuda") # Load character image and preprocessed videos @@ -330,11 +327,11 @@ output = pipe( negative_prompt=negative_prompt, height=height, width=width, - num_frames=81, - guidance_scale=5.0, - mode="animation", # Animation mode (default) + segment_frame_length=77, + guidance_scale=1.0, + mode="animate", # Animation mode (default) ).frames[0] -export_to_video(output, "animated_character.mp4", fps=16) +export_to_video(output, "animated_character.mp4", fps=30) ``` @@ -345,14 +342,10 @@ import numpy as np import torch from diffusers import AutoencoderKLWan, WanAnimatePipeline from diffusers.utils import export_to_video, load_image, load_video -from transformers import CLIPVisionModel model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers" -image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float16) vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) -pipe = WanAnimatePipeline.from_pretrained( - model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 -) +pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) pipe.to("cuda") # Load all required inputs for replacement mode @@ -387,11 +380,11 @@ output = pipe( negative_prompt=negative_prompt, height=height, width=width, - num_frames=81, - guidance_scale=5.0, - mode="replacement", # Replacement mode + segment_frame_lengths=77, + guidance_scale=1.0, + mode="replace", # Replacement mode ).frames[0] -export_to_video(output, "character_replaced.mp4", 
fps=16) +export_to_video(output, "character_replaced.mp4", fps=30) ``` @@ -402,14 +395,10 @@ import numpy as np import torch from diffusers import AutoencoderKLWan, WanAnimatePipeline from diffusers.utils import export_to_video, load_image, load_video -from transformers import CLIPVisionModel model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers" -image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float16) vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) -pipe = WanAnimatePipeline.from_pretrained( - model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 -) +pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) pipe.to("cuda") image = load_image("path/to/character.jpg") @@ -443,14 +432,14 @@ output = pipe( negative_prompt=negative_prompt, height=height, width=width, - num_frames=81, + segment_frame_length=77, num_inference_steps=50, guidance_scale=5.0, - num_frames_for_temporal_guidance=5, # Use 5 frames for temporal guidance (1 or 5 recommended) + prev_segment_conditioning_frames=5, # Use 5 frames for temporal guidance (1 or 5 recommended) callback_on_step_end=callback_fn, callback_on_step_end_tensor_inputs=["latents"], ).frames[0] -export_to_video(output, "animated_advanced.mp4", fps=16) +export_to_video(output, "animated_advanced.mp4", fps=30) ``` @@ -458,10 +447,9 @@ export_to_video(output, "animated_advanced.mp4", fps=16) #### Key Parameters -- **mode**: Choose between `"animation"` (default) or `"replacement"` -- **num_frames_for_temporal_guidance**: Number of frames for temporal guidance (1 or 5 recommended). Using 5 provides better temporal consistency but requires more memory -- **guidance_scale**: Controls how closely the output follows the text prompt. Higher values (5-7) produce results more aligned with the prompt -- **num_frames**: Total number of frames to generate. Should be divisible by `vae_scale_factor_temporal` (default: 4) +- **mode**: Choose between `"animate"` (default) or `"replace"` +- **prev_segment_conditioning_frames**: Number of frames for temporal guidance (1 or 5 recommended). Using 5 provides better temporal consistency but requires more memory +- **guidance_scale**: Controls how closely the output follows the text prompt. Higher values (5-7) produce results more aligned with the prompt. For Wan-Animate, CFG is disabled by default (`guidance_scale=1.0`) but can be enabled to support negative prompts and finer control over facial expressions. (Note that CFG will only target the text prompt and face conditioning.) ## Notes From 61f045ab733a09ac39997d9cf7959fddb888dab7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 15 Nov 2025 11:20:56 +0300 Subject: [PATCH 15/35] Fix: Add `image_dim` `None` --- src/diffusers/models/transformers/transformer_wan_s2v.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 4d810eeca630..61a0ed3377e9 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -838,6 +838,8 @@ class WanS2VTransformer3DModel( Epsilon value for normalization layers. add_img_emb (`bool`, defaults to `False`): Whether to use img_emb. + image_dim (`int`, *optional*, defaults to `None`): + The dimension of image embeddings. Set to `None` for S2V model as it doesn't use image conditioning. 
added_kv_proj_dim (`int`, *optional*, defaults to `None`): The number of channels to use for the added key and value projections. If `None`, no projection is used. zero_timestep (`bool`, defaults to `True`): @@ -872,6 +874,7 @@ def __init__( cross_attn_norm: bool = True, qk_norm: Optional[str] = "rms_norm_across_heads", eps: float = 1e-6, + image_dim: Optional[int] = None, added_kv_proj_dim: Optional[int] = None, rope_max_seq_len: int = 1024, enable_framepack: bool = True, From 9392d8f352971d4c1fff804411796c47920cd7c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 15 Nov 2025 11:21:23 +0300 Subject: [PATCH 16/35] up --- .../models/transformers/transformer_wan_s2v.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 61a0ed3377e9..ea042ce8b778 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -376,19 +376,21 @@ def __init__( ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion patch_size=(1, 2, 2), + in_channels=16, ): super().__init__() self.inner_dim = inner_dim self.num_attention_heads = num_attention_heads + self.in_channels = in_channels if (inner_dim % num_attention_heads) != 0 or (inner_dim // num_attention_heads) % 2 != 0: raise ValueError( "inner_dim must be divisible by num_attention_heads and inner_dim // num_attention_heads must be even" ) self.drop_mode = drop_mode - self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) - self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) - self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + self.proj = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long) self.rope = WanS2VRotaryPosEmbed( @@ -401,7 +403,7 @@ def __init__( def forward(self, motion_latents, add_last_motion=2): latent_height, latent_width = motion_latents.shape[3], motion_latents.shape[4] padd_latent = torch.zeros( - (motion_latents.shape[0], 16, self.zip_frame_buckets.sum(), latent_height, latent_width), + (motion_latents.shape[0], self.in_channels, self.zip_frame_buckets.sum(), latent_height, latent_width), device=motion_latents.device, dtype=motion_latents.dtype, ) @@ -898,6 +900,7 @@ def __init__( zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode, patch_size=patch_size, + in_channels=in_channels, ) else: self.motion_in = Motioner( From 1c626eb7d5cb0ee2a68268a817313d000ce5853d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 15 Nov 2025 11:22:31 +0300 Subject: [PATCH 17/35] style --- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/transformers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 4 +--- tests/quantization/gguf/test_gguf.py | 2 +- 5 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 28e99dea1ccd..356ebf891b4d 100644 --- 
a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -981,8 +981,8 @@ UNetSpatioTemporalConditionModel, UVit2DModel, VQModel, - WanS2VTransformer3DModel, WanAnimateTransformer3DModel, + WanS2VTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, attention_backend, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 312e990af7fa..dad00021d4fa 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -216,8 +216,8 @@ T5FilmDecoder, Transformer2DModel, TransformerTemporalModel, - WanS2VTransformer3DModel, WanAnimateTransformer3DModel, + WanS2VTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, ) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index d40669295e13..2286c2c120b3 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -42,6 +42,6 @@ from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel from .transformer_temporal import TransformerTemporalModel from .transformer_wan import WanTransformer3DModel - from .transformer_wan_s2v import WanS2VTransformer3DModel from .transformer_wan_animate import WanAnimateTransformer3DModel + from .transformer_wan_s2v import WanS2VTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7b3885645c43..153a4113349c 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -811,12 +811,10 @@ ) from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline from .wan import ( - WanImageToVideoPipeline, - WanPipeline, - WanSpeechToVideoPipeline, WanAnimatePipeline, WanImageToVideoPipeline, WanPipeline, + WanSpeechToVideoPipeline, WanVACEPipeline, WanVideoToVideoPipeline, ) diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index ff466d36aca5..39c5dd4d6a63 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -16,8 +16,8 @@ HiDreamImageTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline, - WanS2VTransformer3DModel, WanAnimateTransformer3DModel, + WanS2VTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, ) From a480ecc9e71f73f21b555512cda9effa05dbce26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 15 Nov 2025 11:26:27 +0300 Subject: [PATCH 18/35] up --- scripts/convert_wan_to_diffusers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 816bf494118a..ead4924273a4 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -479,8 +479,6 @@ def convert_transformer(model_type: str, stage: str = None): if "S2V" in model_type: transformer = WanS2VTransformer3DModel.from_config(diffusers_config) elif "VACE" not in model_type: - transformer = WanTransformer3DModel.from_config(diffusers_config) - else: transformer = WanVACETransformer3DModel.from_config(diffusers_config) else: transformer = WanTransformer3DModel.from_config(diffusers_config) From 7ea88d9b0097cc2a82efe623e4d5a3e960c3b0f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 15 Nov 2025 11:49:44 +0300 Subject: [PATCH 19/35] revert --- scripts/convert_wan_to_diffusers.py | 230 +++++++++++++++++++++++++++- 1 file changed, 227 insertions(+), 
3 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index ead4924273a4..71e6227aa47f 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -120,6 +120,42 @@ "after_proj": "proj_out", } +ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", + "img_emb.proj.0": "condition_embedder.image_embedder.norm1", + "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", + "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", + "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + # Add attention component mappings + "self_attn.q": "attn1.to_q", + "self_attn.k": "attn1.to_k", + "cross_attn.o": "attn2.to_out.0", + "cross_attn.norm_q": "attn2.norm_q", + "cross_attn.norm_k": "attn2.norm_k", + "cross_attn.k_img": "attn2.to_k_img", + "cross_attn.v_img": "attn2.to_v_img", + "cross_attn.norm_k_img": "attn2.norm_k_img", + # After cross_attn -> attn2 rename, we need to rename the img keys + "attn2.to_k_img": "attn2.add_k_proj", + "attn2.to_v_img": "attn2.add_v_proj", + "attn2.norm_k_img": "attn2.norm_added_k", + # Wan Animate-specific mappings (motion encoder, face encoder, face adapter) + # Motion encoder mappings + # The name mapping is complicated for the convolutional part so we handle that in its own function + "motion_encoder.enc.fc": "motion_encoder.motion_network", + "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight", + # Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten + "face_encoder.conv1_local.conv": "face_encoder.conv1_local", + "face_encoder.conv2.conv": "face_encoder.conv2", + "face_encoder.conv3.conv": "face_encoder.conv3", + # Face adapter mappings are handled in a separate function +} + S2V_TRANSFORMER_KEYS_RENAME_DICT = { "time_embedding.0": "condition_embedder.time_embedder.linear_1", "time_embedding.2": "condition_embedder.time_embedder.linear_2", @@ -170,8 +206,149 @@ }, } +# TODO: Verify this and simplify if possible. +def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None: + """ + Convert all motion encoder weights for Animate model. + + In the original model: + - All Linear layers in fc use EqualLinear + - All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately) + - Blur kernels are stored as buffers in Sequential modules + - ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)] + + Conversion strategy: + 1. Drop .kernel buffers (blur kernels) + 2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu) + """ + # Skip if not a weight, bias, or kernel + if ".weight" not in key and ".bias" not in key and ".kernel" not in key: + return + + # Handle Blur kernel buffers from original implementation. 
+ # After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel + # Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys + if ".kernel" in key and "motion_encoder" in key: + # Remove unexpected blur kernel buffers to avoid strict load errors + state_dict.pop(key, None) + return + + # Rename Sequential indices to named components in ConvLayer and ResBlock + if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key): + parts = key.split(".") + + # Find the sequential index (digit) after convs or after conv1/conv2/skip + # Examples: + # - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight) + # - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias) + # - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight) + # - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight + # - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias) + # - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias + # - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight) + # - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias) + # - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight) + # - enc.net_app.convs.8 -> conv_out (final conv layer) + + convs_idx = parts.index("convs") if "convs" in parts else -1 + if convs_idx >= 0 and len(parts) - convs_idx >= 2: + bias = False + # The nn.Sequential index will always follow convs + sequential_idx = int(parts[convs_idx + 1]) + if sequential_idx == 0: + if key.endswith(".weight"): + new_key = "motion_encoder.conv_in.weight" + elif key.endswith(".bias"): + new_key = "motion_encoder.conv_in.act_fn.bias" + bias = True + elif sequential_idx == final_conv_idx: + if key.endswith(".weight"): + new_key = "motion_encoder.conv_out.weight" + else: + # Intermediate .convs. layers, which get mapped to .res_blocks. + prefix = "motion_encoder.res_blocks." + + layer_name = parts[convs_idx + 2] + if layer_name == "skip": + layer_name = "conv_skip" + + if key.endswith(".weight"): + param_name = "weight" + elif key.endswith(".bias"): + param_name = "act_fn.bias" + bias = True + + suffix_parts = [str(sequential_idx - 1), layer_name, param_name] + suffix = ".".join(suffix_parts) + new_key = prefix + suffix + + param = state_dict.pop(key) + if bias: + param = param.squeeze() + state_dict[new_key] = param + return + return + return + + +def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None: + """ + Convert face adapter weights for the Animate model. + + The original model uses a fused KV projection but the diffusers models uses separate K and V projections. + """ + # Skip if not a weight or bias + if ".weight" not in key and ".bias" not in key: + return + + prefix = "face_adapter." + if ".fuser_blocks." 
in key: + parts = key.split(".") + + module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1 + if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3: + block_idx = parts[module_list_idx + 1] + layer_name = parts[module_list_idx + 2] + param_name = parts[module_list_idx + 3] + + if layer_name == "linear1_kv": + layer_name_k = "to_k" + layer_name_v = "to_v" + + suffix_k = ".".join([block_idx, layer_name_k, param_name]) + suffix_v = ".".join([block_idx, layer_name_v, param_name]) + new_key_k = prefix + suffix_k + new_key_v = prefix + suffix_v + + kv_proj = state_dict.pop(key) + k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0) + state_dict[new_key_k] = k_proj + state_dict[new_key_v] = v_proj + return + else: + if layer_name == "q_norm": + new_layer_name = "norm_q" + elif layer_name == "k_norm": + new_layer_name = "norm_k" + elif layer_name == "linear1_q": + new_layer_name = "to_q" + elif layer_name == "linear2": + new_layer_name = "to_out" + + suffix_parts = [block_idx, new_layer_name, param_name] + suffix = ".".join(suffix_parts) + new_key = prefix + suffix + state_dict[new_key] = state_dict.pop(key) + return + return + + TRANSFORMER_SPECIAL_KEYS_REMAP = {} VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {} +ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = { + "motion_encoder": convert_animate_motion_encoder_weights, + "face_adapter": convert_animate_face_adapter_weights, +} S2V_TRANSFORMER_SPECIAL_KEYS_REMAP = {} @@ -430,6 +607,37 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: } RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP + elif model_type == "Wan2.2-Animate-14B": + config = { + "model_id": "Wan-AI/Wan2.2-Animate-14B", + "diffusers_config": { + "image_dim": 1280, + "added_kv_proj_dim": 5120, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 36, + "num_attention_heads": 40, + "num_layers": 40, + "out_channels": 16, + "patch_size": (1, 2, 2), + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + "rope_max_seq_len": 1024, + "pos_embed_seq_len": None, + "motion_encoder_size": 512, # Start of Wan Animate-specific configs + "motion_style_dim": 512, + "motion_dim": 20, + "motion_encoder_dim": 512, + "face_encoder_hidden_dim": 1024, + "face_encoder_num_heads": 4, + "inject_face_latents_blocks": 5, + }, + } + RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT + SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP elif model_type == "Wan2.2-S2V-14B": config = { "model_id": "Wan-AI/Wan2.2-S2V-14B", @@ -444,7 +652,7 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: "num_attention_heads": 40, "num_layers": 40, "out_channels": 16, - "patch_size": [1, 2, 2], + "patch_size": (1, 2, 2), "qk_norm": "rms_norm_across_heads", "text_dim": 4096, "audio_dim": 1024, @@ -478,6 +686,8 @@ def convert_transformer(model_type: str, stage: str = None): with init_empty_weights(): if "S2V" in model_type: transformer = WanS2VTransformer3DModel.from_config(diffusers_config) + elif "Animate" in model_type: + transformer = WanAnimateTransformer3DModel.from_config(diffusers_config) elif "VACE" not in model_type: transformer = WanVACETransformer3DModel.from_config(diffusers_config) else: @@ -1029,7 +1239,7 @@ def get_args(): if __name__ == "__main__": args = get_args() - if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "S2V" not in args.model_type: + if "Wan2.2" in args.model_type and not 
any(tag in args.model_type for tag in ("TI2V", "Animate", "S2V")): transformer = convert_transformer(args.model_type, stage="high_noise_model") transformer_2 = convert_transformer(args.model_type, stage="low_noise_model") else: @@ -1045,7 +1255,7 @@ def get_args(): tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") if "FLF2V" in args.model_type: flow_shift = 16.0 - elif "TI2V" in args.model_type or "S2V" in args.model_type: + elif any(tag in args.model_type for tag in ("TI2V", "Animate", "S2V")): flow_shift = 5.0 else: flow_shift = 3.0 @@ -1121,6 +1331,20 @@ def get_args(): vae=vae, scheduler=scheduler, ) + elif "Animate" in args.model_type: + image_encoder = CLIPVisionModel.from_pretrained( + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16 + ) + image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + pipe = WanAnimatePipeline( + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) elif "S2V" in args.model_type: audio_encoder = Wav2Vec2ForCTC.from_pretrained( "Wan-AI/Wan2.2-S2V-14B", subfolder="wav2vec2-large-xlsr-53-english" From ff8f30962b67c51fc3fd62bbd234f01a314cfb9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 15 Nov 2025 11:54:19 +0300 Subject: [PATCH 20/35] update --- scripts/convert_wan_to_diffusers.py | 19 ++++++++++++++++++- src/diffusers/utils/dummy_pt_objects.py | 1 - tests/quantization/gguf/test_gguf.py | 22 ++++++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 71e6227aa47f..1700980f475e 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -124,6 +124,16 @@ "time_embedding.0": "condition_embedder.time_embedder.linear_1", "time_embedding.2": "condition_embedder.time_embedder.linear_2", "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 "norm2": "norm__placeholder", "norm3": "norm2", "norm__placeholder": "norm3", @@ -134,6 +144,13 @@ # Add attention component mappings "self_attn.q": "attn1.to_q", "self_attn.k": "attn1.to_k", + "self_attn.v": "attn1.to_v", + "self_attn.o": "attn1.to_out.0", + "self_attn.norm_q": "attn1.norm_q", + "self_attn.norm_k": "attn1.norm_k", + "cross_attn.q": "attn2.to_q", + "cross_attn.k": "attn2.to_k", + "cross_attn.v": "attn2.to_v", "cross_attn.o": "attn2.to_out.0", "cross_attn.norm_q": "attn2.norm_q", "cross_attn.norm_k": "attn2.norm_k", @@ -688,7 +705,7 @@ def convert_transformer(model_type: str, stage: str = None): transformer = WanS2VTransformer3DModel.from_config(diffusers_config) elif "Animate" in model_type: transformer = WanAnimateTransformer3DModel.from_config(diffusers_config) - elif "VACE" not in model_type: + elif "VACE" in model_type: transformer = WanVACETransformer3DModel.from_config(diffusers_config) else: transformer = WanTransformer3DModel.from_config(diffusers_config) diff --git a/src/diffusers/utils/dummy_pt_objects.py 
b/src/diffusers/utils/dummy_pt_objects.py index 886ef6dee2be..f56a8b932505 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1623,7 +1623,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class WanS2VTransformer3DModel(metaclass=DummyObject): class WanAnimateTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 39c5dd4d6a63..396be246ce59 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -727,6 +727,28 @@ class WanS2VGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): ckpt_path = "https://huggingface.co/QuantStack/Wan2.2-S2V-14B-GGUF/blob/main/Wan2.2-S2V-14B-Q3_K_S.gguf" torch_dtype = torch.bfloat16 model_cls = WanS2VTransformer3DModel + expected_memory_use_in_gb = 9 + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states": torch.randn( + (1, 96, 2, 64, 64), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states_scale": torch.randn( + (8,), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + class WanAnimateGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): ckpt_path = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q3_K_S.gguf" torch_dtype = torch.bfloat16 From c677ee739bea8faa714a5622ed8a1082c9ff67d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 15 Nov 2025 11:55:45 +0300 Subject: [PATCH 21/35] style --- scripts/convert_wan_to_diffusers.py | 7 +++---- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 2 +- tests/quantization/gguf/test_gguf.py | 1 + 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 1700980f475e..30324f143e38 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -9,14 +9,12 @@ from transformers import ( AutoProcessor, AutoTokenizer, - CLIPVisionModelWithProjection, - UMT5EncoderModel, - Wav2Vec2ForCTC, - Wav2Vec2Processor, CLIPImageProcessor, CLIPVisionModel, CLIPVisionModelWithProjection, UMT5EncoderModel, + Wav2Vec2ForCTC, + Wav2Vec2Processor, ) from diffusers import ( @@ -223,6 +221,7 @@ }, } + # TODO: Verify this and simplify if possible. 
def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None: """ diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 356ebf891b4d..c3b456333fdb 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -268,8 +268,8 @@ "UNetSpatioTemporalConditionModel", "UVit2DModel", "VQModel", - "WanS2VTransformer3DModel", "WanAnimateTransformer3DModel", + "WanS2VTransformer3DModel", "WanTransformer3DModel", "WanVACETransformer3DModel", "attention_backend", diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index dad00021d4fa..89d7debb34b1 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -108,8 +108,8 @@ _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] - _import_structure["transformers.transformer_wan_s2v"] = ["WanS2VTransformer3DModel"] _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"] + _import_structure["transformers.transformer_wan_s2v"] = ["WanS2VTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 396be246ce59..98969b55b727 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -749,6 +749,7 @@ def get_dummy_inputs(self): "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), } + class WanAnimateGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): ckpt_path = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q3_K_S.gguf" torch_dtype = torch.bfloat16 From 065d9823115a4a12612022484416dfcef92fbe67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 15 Nov 2025 12:01:56 +0300 Subject: [PATCH 22/35] up --- docs/source/en/api/pipelines/wan.md | 258 ++++++++++++++-------------- 1 file changed, 129 insertions(+), 129 deletions(-) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index fbadc1f191e3..d1cfb8b3588c 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -40,8 +40,8 @@ The following Wan models are supported in Diffusers: - [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) - [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) - [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers) -- [Wan 2.2 S2V 14B](https://huggingface.co/Wan-AI/Wan2.2-S2V-14B-Diffusers) - [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers) +- [Wan 2.2 S2V 14B](https://huggingface.co/Wan-AI/Wan2.2-S2V-14B-Diffusers) > [!TIP] > Click on the Wan models in the right sidebar for more examples of video generation. @@ -239,128 +239,6 @@ export_to_video(output, "output.mp4", fps=16) -### Wan-S2V: Audio-Driven Cinematic Video Generation - -[Wan-S2V](https://huggingface.co/papers/2508.18621) by the Wan Team. - -*Current state-of-the-art (SOTA) methods for audio-driven character animation demonstrate promising performance for scenarios primarily involving speech and singing. 
However, they often fall short in more complex film and television productions, which demand sophisticated elements such as nuanced character interactions, realistic body movements, and dynamic camera work. To address this long-standing challenge of achieving film-level character animation, we propose an audio-driven model, which we refere to as Wan-S2V, built upon Wan. Our model achieves significantly enhanced expressiveness and fidelity in cinematic contexts compared to existing approaches. We conducted extensive experiments, benchmarking our method against cutting-edge models such as Hunyuan-Avatar and Omnihuman. The experimental results consistently demonstrate that our approach significantly outperforms these existing solutions. Additionally, we explore the versatility of our method through its applications in long-form video generation and precise video lip-sync editing.* - -The project page: https://humanaigc.github.io/wan-s2v-webpage/ - -This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz). - -The example below demonstrates how to use the speech-to-video pipeline to generate a video using a text description, a starting frame, an audio, and a pose video. - - - - -```python -import numpy as np, math -import torch -from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline -from diffusers.utils import export_to_merged_video_audio, load_image, load_audio, load_video, export_to_video -from transformers import Wav2Vec2ForCTC -import requests -from PIL import Image -from io import BytesIO - - -model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" -audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32) -vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) -pipe = WanSpeechToVideoPipeline.from_pretrained( - model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16 -) -pipe.to("cuda") - -headers = {"User-Agent": "Mozilla/5.0"} -url = "https://upload.wikimedia.org/wikipedia/commons/4/46/Albert_Einstein_sticks_his_tongue.jpg" -resp = requests.get(url, headers=headers, timeout=30) -image = Image.open(BytesIO(resp.content)) - -audio, sampling_rate = load_audio("https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/Five%20Hundred%20Miles.MP3") -#pose_video_path_or_url = "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4" - -def get_size_less_than_area(height, - width, - target_area=1024 * 704, - divisor=64): - if height * width <= target_area: - # If the original image area is already less than or equal to the target, - # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target. 
- max_upper_area = target_area - min_scale = 0.1 - max_scale = 1.0 - else: - # Resize to fit within the target area and then pad to multiples of `divisor` - max_upper_area = target_area # Maximum allowed total pixel count after padding - d = divisor - 1 - b = d * (height + width) - a = height * width - c = d**2 - max_upper_area - - # Calculate scale boundaries using quadratic equation - min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (2 * a) # Scale when maximum padding is applied - max_scale = math.sqrt(max_upper_area / (height * width)) # Scale without any padding - - # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area - # Use binary search-like iteration to find this scale - find_it = False - for i in range(100): - scale = max_scale - (max_scale - min_scale) * i / 100 - new_height, new_width = int(height * scale), int(width * scale) - - # Pad to make dimensions divisible by 64 - pad_height = (64 - new_height % 64) % 64 - pad_width = (64 - new_width % 64) % 64 - pad_top = pad_height // 2 - pad_bottom = pad_height - pad_top - pad_left = pad_width // 2 - pad_right = pad_width - pad_left - - padded_height, padded_width = new_height + pad_height, new_width + pad_width - - if padded_height * padded_width <= max_upper_area: - find_it = True - break - - if find_it: - return padded_height, padded_width - else: - # Fallback: calculate target dimensions based on aspect ratio and divisor alignment - aspect_ratio = width / height - target_width = int( - (target_area * aspect_ratio)**0.5 // divisor * divisor) - target_height = int( - (target_area / aspect_ratio)**0.5 // divisor * divisor) - - # Ensure the result is not larger than the original resolution - if target_width >= width or target_height >= height: - target_width = int(width // divisor * divisor) - target_height = int(height // divisor * divisor) - - return target_height, target_width - -height, width = get_size_less_than_area(first_frame.height, first_frame.width, 480*832) - -prompt = "Einstein singing a song." - -output = pipe( - prompt=prompt, image=image, audio=audio, sampling_rate=sampling_rate, - height=height, width=width, num_frames_per_chunk=80, - #pose_video_path_or_url=pose_video_path_or_url, -).frames[0] -export_to_video(output, "output.mp4", fps=16) - -# Lastly, we need to merge the video and audio into a new video, with the duration set to -# the shorter of the two and overwrite the original video file. -export_to_merged_video_audio("output.mp4", "audio.mp3") -``` - - - - - ### Any-to-Video Controllable Generation Wan VACE supports various generation techniques which achieve controllable video generation. Some of the capabilities include: @@ -588,6 +466,128 @@ export_to_video(output, "animated_advanced.mp4", fps=16) - **num_frames**: Total number of frames to generate. Should be divisible by `vae_scale_factor_temporal` (default: 4) +### Wan-S2V: Audio-Driven Cinematic Video Generation + +[Wan-S2V](https://huggingface.co/papers/2508.18621) by the Wan Team. + +*Current state-of-the-art (SOTA) methods for audio-driven character animation demonstrate promising performance for scenarios primarily involving speech and singing. However, they often fall short in more complex film and television productions, which demand sophisticated elements such as nuanced character interactions, realistic body movements, and dynamic camera work. 
To address this long-standing challenge of achieving film-level character animation, we propose an audio-driven model, which we refere to as Wan-S2V, built upon Wan. Our model achieves significantly enhanced expressiveness and fidelity in cinematic contexts compared to existing approaches. We conducted extensive experiments, benchmarking our method against cutting-edge models such as Hunyuan-Avatar and Omnihuman. The experimental results consistently demonstrate that our approach significantly outperforms these existing solutions. Additionally, we explore the versatility of our method through its applications in long-form video generation and precise video lip-sync editing.* + +The project page: https://humanaigc.github.io/wan-s2v-webpage/ + +This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz). + +The example below demonstrates how to use the speech-to-video pipeline to generate a video using a text description, a starting frame, an audio, and a pose video. + + + + +```python +import numpy as np, math +import torch +from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline +from diffusers.utils import export_to_merged_video_audio, load_image, load_audio, load_video, export_to_video +from transformers import Wav2Vec2ForCTC +import requests +from PIL import Image +from io import BytesIO + + +model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" +audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32) +vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +pipe = WanSpeechToVideoPipeline.from_pretrained( + model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +headers = {"User-Agent": "Mozilla/5.0"} +url = "https://upload.wikimedia.org/wikipedia/commons/4/46/Albert_Einstein_sticks_his_tongue.jpg" +resp = requests.get(url, headers=headers, timeout=30) +image = Image.open(BytesIO(resp.content)) + +audio, sampling_rate = load_audio("https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/Five%20Hundred%20Miles.MP3") +#pose_video_path_or_url = "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4" + +def get_size_less_than_area(height, + width, + target_area=1024 * 704, + divisor=64): + if height * width <= target_area: + # If the original image area is already less than or equal to the target, + # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target. 
+ max_upper_area = target_area + min_scale = 0.1 + max_scale = 1.0 + else: + # Resize to fit within the target area and then pad to multiples of `divisor` + max_upper_area = target_area # Maximum allowed total pixel count after padding + d = divisor - 1 + b = d * (height + width) + a = height * width + c = d**2 - max_upper_area + + # Calculate scale boundaries using quadratic equation + min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (2 * a) # Scale when maximum padding is applied + max_scale = math.sqrt(max_upper_area / (height * width)) # Scale without any padding + + # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area + # Use binary search-like iteration to find this scale + find_it = False + for i in range(100): + scale = max_scale - (max_scale - min_scale) * i / 100 + new_height, new_width = int(height * scale), int(width * scale) + + # Pad to make dimensions divisible by 64 + pad_height = (64 - new_height % 64) % 64 + pad_width = (64 - new_width % 64) % 64 + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top + pad_left = pad_width // 2 + pad_right = pad_width - pad_left + + padded_height, padded_width = new_height + pad_height, new_width + pad_width + + if padded_height * padded_width <= max_upper_area: + find_it = True + break + + if find_it: + return padded_height, padded_width + else: + # Fallback: calculate target dimensions based on aspect ratio and divisor alignment + aspect_ratio = width / height + target_width = int( + (target_area * aspect_ratio)**0.5 // divisor * divisor) + target_height = int( + (target_area / aspect_ratio)**0.5 // divisor * divisor) + + # Ensure the result is not larger than the original resolution + if target_width >= width or target_height >= height: + target_width = int(width // divisor * divisor) + target_height = int(height // divisor * divisor) + + return target_height, target_width + +height, width = get_size_less_than_area(first_frame.height, first_frame.width, 480*832) + +prompt = "Einstein singing a song." + +output = pipe( + prompt=prompt, image=image, audio=audio, sampling_rate=sampling_rate, + height=height, width=width, num_frames_per_chunk=80, + #pose_video_path_or_url=pose_video_path_or_url, +).frames[0] +export_to_video(output, "output.mp4", fps=16) + +# Lastly, we need to merge the video and audio into a new video, with the duration set to +# the shorter of the two and overwrite the original video file. +export_to_merged_video_audio("output.mp4", "audio.mp3") +``` + + + + + ## Notes - Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`]. 
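The LoRA bullet above only points at [`~loaders.WanLoraLoaderMixin.load_lora_weights`] without showing the call at this point in the hunk. Below is a minimal, hedged sketch of how a LoRA is typically attached to a Wan pipeline; the `Wan-AI/Wan2.1-T2V-14B-Diffusers` checkpoint id is assumed to be the public Wan 2.1 T2V weights, and the LoRA repository and adapter name are placeholders rather than references to a real checkpoint.

```python
# Minimal sketch: attaching a LoRA to a Wan pipeline via WanLoraLoaderMixin.
# "your-org/wan-style-lora" and the adapter name "style" are placeholders.
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Load the LoRA weights and give the adapter a name, then set its strength.
pipe.load_lora_weights("your-org/wan-style-lora", adapter_name="style")
pipe.set_adapters(["style"], adapter_weights=[0.75])

video = pipe(
    prompt="A cat walks on the grass, realistic style",
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
export_to_video(video, "wan_lora_t2v.mp4", fps=16)
```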
@@ -692,12 +692,6 @@ export_to_video(output, "animated_advanced.mp4", fps=16) - all - __call__ -## WanSpeechToVideoPipeline - -[[autodoc]] WanSpeechToVideoPipeline - - all - - __call__ - ## WanVideoToVideoPipeline [[autodoc]] WanVideoToVideoPipeline @@ -710,6 +704,12 @@ export_to_video(output, "animated_advanced.mp4", fps=16) - all - __call__ +## WanSpeechToVideoPipeline + +[[autodoc]] WanSpeechToVideoPipeline + - all + - __call__ + ## WanPipelineOutput [[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput From 01a56927f1603f1e89d1e5ada74d2aa75da2d46b Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Sat, 15 Nov 2025 16:14:34 +0100 Subject: [PATCH 23/35] Rope in float32 for mps or npu compatibility (#12665) rope in float32 --- src/diffusers/models/transformers/transformer_prx.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py index 9b2664b9cb26..ccbc83ffca03 100644 --- a/src/diffusers/models/transformers/transformer_prx.py +++ b/src/diffusers/models/transformers/transformer_prx.py @@ -275,7 +275,12 @@ def __init__(self, dim: int, theta: int, axes_dim: List[int]): def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0 - scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + + is_mps = pos.device.type == "mps" + is_npu = pos.device.type == "npu" + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + + scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim omega = 1.0 / (theta**scale) out = pos.unsqueeze(-1) * omega.unsqueeze(0) out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) From 0c35b580fe83d223fbdf2a66d11167837abdd115 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 17 Nov 2025 06:07:40 +0100 Subject: [PATCH 24/35] [PRX pipeline]: add 1024 resolution ratio bins (#12670) add 1024 ratio bins --- src/diffusers/pipelines/prx/pipeline_prx.py | 43 +++++++++++++++++++-- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/prx/pipeline_prx.py b/src/diffusers/pipelines/prx/pipeline_prx.py index df598a5715d2..873f25316e6d 100644 --- a/src/diffusers/pipelines/prx/pipeline_prx.py +++ b/src/diffusers/pipelines/prx/pipeline_prx.py @@ -69,6 +69,39 @@ "2.0": [704, 352], } +ASPECT_RATIO_1024_BIN = { + "0.49": [704, 1440], + "0.52": [736, 1408], + "0.53": [736, 1376], + "0.57": [768, 1344], + "0.59": [768, 1312], + "0.62": [800, 1280], + "0.67": [832, 1248], + "0.68": [832, 1216], + "0.78": [896, 1152], + "0.83": [928, 1120], + "0.94": [992, 1056], + "1.0": [1024, 1024], + "1.06": [1056, 992], + "1.13": [1088, 960], + "1.21": [1120, 928], + "1.29": [1152, 896], + "1.37": [1184, 864], + "1.46": [1216, 832], + "1.5": [1248, 832], + "1.71": [1312, 768], + "1.75": [1344, 768], + "1.87": [1376, 736], + "1.91": [1408, 736], + "2.05": [1440, 704], +} + +ASPECT_RATIO_BINS = { + 256: ASPECT_RATIO_256_BIN, + 512: ASPECT_RATIO_512_BIN, + 1024: ASPECT_RATIO_1024_BIN, +} + logger = logging.get_logger(__name__) @@ -600,10 +633,12 @@ def __call__( "Resolution binning requires a VAE with image_processor, but VAE is not available. " "Set use_resolution_binning=False or provide a VAE." 
) - if self.default_sample_size <= 256: - aspect_ratio_bin = ASPECT_RATIO_256_BIN - else: - aspect_ratio_bin = ASPECT_RATIO_512_BIN + if self.default_sample_size not in ASPECT_RATIO_BINS: + raise ValueError( + f"Resolution binning is only supported for default_sample_size in {list(ASPECT_RATIO_BINS.keys())}, " + f"but got {self.default_sample_size}. Set use_resolution_binning=False to disable aspect ratio binning." + ) + aspect_ratio_bin = ASPECT_RATIO_BINS[self.default_sample_size] # Store original dimensions orig_height, orig_width = height, width From 1afc21855eb1f5575bd61037a7ee44522ccf401e Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Mon, 17 Nov 2025 16:23:34 +0800 Subject: [PATCH 25/35] SANA-Video Image to Video pipeline `SanaImageToVideoPipeline` support (#12634) * move sana-video to a new dir and add `SanaImageToVideoPipeline` with no modify; * fix bug and run text/image-to-vidoe success; * make style; quality; fix-copies; * add sana image-to-video pipeline in markdown; * add test case for sana image-to-video; * make style; * add a init file in sana-video test dir; * Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update tests/pipelines/sana_video/test_sana_video_i2v.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update tests/pipelines/sana_video/test_sana_video_i2v.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * minor update; * fix bug and skip fp16 save test; Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com> * Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * add copied from for `encode_prompt` * Apply style fixes --------- Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com> Co-authored-by: github-actions[bot] --- docs/source/en/api/pipelines/sana_video.md | 90 +- scripts/convert_sana_video_to_diffusers.py | 3 + src/diffusers/__init__.py | 3 + .../transformers/transformer_sana_video.py | 14 +- src/diffusers/pipelines/__init__.py | 5 +- src/diffusers/pipelines/sana/__init__.py | 2 - .../pipelines/sana/pipeline_output.py | 16 - .../pipelines/sana_video/__init__.py | 49 + .../pipelines/sana_video/pipeline_output.py | 20 + .../pipeline_sana_video.py | 14 +- .../sana_video/pipeline_sana_video_i2v.py | 1066 +++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + tests/pipelines/sana_video/__init__.py | 0 .../{sana => sana_video}/test_sana_video.py | 0 .../sana_video/test_sana_video_i2v.py | 238 ++++ 15 files changed, 1501 insertions(+), 34 deletions(-) create mode 100644 src/diffusers/pipelines/sana_video/__init__.py create mode 100644 
src/diffusers/pipelines/sana_video/pipeline_output.py rename src/diffusers/pipelines/{sana => sana_video}/pipeline_sana_video.py (98%) create mode 100644 src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py create mode 100644 tests/pipelines/sana_video/__init__.py rename tests/pipelines/{sana => sana_video}/test_sana_video.py (100%) create mode 100644 tests/pipelines/sana_video/test_sana_video_i2v.py diff --git a/docs/source/en/api/pipelines/sana_video.md b/docs/source/en/api/pipelines/sana_video.md index 85d77fb2944b..d69f4a95facc 100644 --- a/docs/source/en/api/pipelines/sana_video.md +++ b/docs/source/en/api/pipelines/sana_video.md @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. --> -# SanaVideoPipeline +# Sana-Video
LoRA @@ -37,6 +37,85 @@ Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-vi Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype. + +## Generation Pipelines + +` + + +The example below demonstrates how to use the text-to-video pipeline to generate a video using a text descriptio and a starting frame. + +```python +model_id = +pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", torch_dtype=torch.bfloat16) +pipe.text_encoder.to(torch.bfloat16) +pipe.vae.to(torch.float32) +pipe.to("cuda") + +prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." +motion_scale = 30 +motion_prompt = f" motion score: {motion_scale}." +prompt = prompt + motion_prompt + +video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=480, + width=832, + frames=81, + guidance_scale=6, + num_inference_steps=50, + generator=torch.Generator(device="cuda").manual_seed(0), +).frames[0] + +export_to_video(video, "sana_video.mp4", fps=16) +``` + + + + +The example below demonstrates how to use the image-to-video pipeline to generate a video using a text descriptio and a starting frame. + +```python +model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers" +pipe = SanaImageToVideoPipeline.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, +) +pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0) +pipe.vae.to(torch.float32) +pipe.text_encoder.to(torch.bfloat16) +pipe.to("cuda") + +image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png") +prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle." +negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." +motion_scale = 30 +motion_prompt = f" motion score: {motion_scale}." 
+prompt = prompt + motion_prompt + +motion_scale = 30.0 + +video = pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + height=480, + width=832, + frames=81, + guidance_scale=6, + num_inference_steps=50, + generator=torch.Generator(device="cuda").manual_seed(0), +).frames[0] + +export_to_video(video, "sana-i2v.mp4", fps=16) +``` + + + + + ## Quantization Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model. @@ -97,6 +176,13 @@ export_to_video(output, "sana-video-output.mp4", fps=16) - __call__ +## SanaImageToVideoPipeline + +[[autodoc]] SanaImageToVideoPipeline + - all + - __call__ + + ## SanaVideoPipelineOutput -[[autodoc]] pipelines.sana.pipeline_sana_video.SanaVideoPipelineOutput +[[autodoc]] pipelines.sana_video.pipeline_sana_video.SanaVideoPipelineOutput diff --git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py index fbb7c1d9e706..a939a06cbd46 100644 --- a/scripts/convert_sana_video_to_diffusers.py +++ b/scripts/convert_sana_video_to_diffusers.py @@ -80,6 +80,8 @@ def main(args): # scheduler flow_shift = 8.0 + if args.task == "i2v": + assert args.scheduler_type == "flow-euler", "Scheduler type must be flow-euler for i2v task." # model config layer_num = 20 @@ -312,6 +314,7 @@ def main(args): choices=["flow-dpm_solver", "flow-euler", "uni-pc"], help="Scheduler type to use.", ) + parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.") parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.") parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 02df34c07e8e..cd7a2cb581b7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -545,11 +545,13 @@ "QwenImagePipeline", "ReduxImageEncoder", "SanaControlNetPipeline", + "SanaImageToVideoPipeline", "SanaPAGPipeline", "SanaPipeline", "SanaSprintImg2ImgPipeline", "SanaSprintPipeline", "SanaVideoPipeline", + "SanaVideoPipeline", "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", @@ -1227,6 +1229,7 @@ QwenImagePipeline, ReduxImageEncoder, SanaControlNetPipeline, + SanaImageToVideoPipeline, SanaPAGPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index f9fc971950b0..a4f90342631a 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -237,7 +237,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin -# Copied from diffusers.models.transformers.sana_transformer.SanaModulatedNorm class SanaModulatedNorm(nn.Module): def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6): super().__init__() @@ -247,7 +246,7 @@ def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor ) -> torch.Tensor: hidden_states = self.norm(hidden_states) - shift, scale = (scale_shift_table[None] + temb[:, 
None].to(scale_shift_table.device)).chunk(2, dim=1) + shift, scale = (scale_shift_table[None, None] + temb[:, :, None].to(scale_shift_table.device)).unbind(dim=2) hidden_states = hidden_states * (1 + scale) + shift return hidden_states @@ -423,8 +422,8 @@ def forward( # 1. Modulation shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) - ).chunk(6, dim=1) + self.scale_shift_table[None, None] + timestep.reshape(batch_size, timestep.shape[1], 6, -1) + ).unbind(dim=2) # 2. Self Attention norm_hidden_states = self.norm1(hidden_states) @@ -635,13 +634,16 @@ def forward( if guidance is not None: timestep, embedded_timestep = self.time_embed( - timestep, guidance=guidance, hidden_dtype=hidden_states.dtype + timestep.flatten(), guidance=guidance, hidden_dtype=hidden_states.dtype ) else: timestep, embedded_timestep = self.time_embed( - timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype + timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype ) + timestep = timestep.view(batch_size, -1, timestep.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 719ff4c7df15..69bb14b98edc 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -308,7 +308,10 @@ "SanaSprintPipeline", "SanaControlNetPipeline", "SanaSprintImg2ImgPipeline", + ] + _import_structure["sana_video"] = [ "SanaVideoPipeline", + "SanaImageToVideoPipeline", ] _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] @@ -749,8 +752,8 @@ SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline, - SanaVideoPipeline, ) + from .sana_video import SanaImageToVideoPipeline, SanaVideoPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_audio import StableAudioPipeline, StableAudioProjectionModel diff --git a/src/diffusers/pipelines/sana/__init__.py b/src/diffusers/pipelines/sana/__init__.py index d5571ab12fac..91684f35f153 100644 --- a/src/diffusers/pipelines/sana/__init__.py +++ b/src/diffusers/pipelines/sana/__init__.py @@ -26,7 +26,6 @@ _import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"] _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"] _import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"] - _import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -40,7 +39,6 @@ from .pipeline_sana_controlnet import SanaControlNetPipeline from .pipeline_sana_sprint import SanaSprintPipeline from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline - from .pipeline_sana_video import SanaVideoPipeline else: import sys diff --git a/src/diffusers/pipelines/sana/pipeline_output.py b/src/diffusers/pipelines/sana/pipeline_output.py index 8021b7738755..f8ac12951644 100644 --- a/src/diffusers/pipelines/sana/pipeline_output.py +++ b/src/diffusers/pipelines/sana/pipeline_output.py @@ -3,7 +3,6 @@ import numpy as np import PIL.Image -import torch from ...utils import BaseOutput @@ -20,18 +19,3 @@ 
class SanaPipelineOutput(BaseOutput): """ images: Union[List[PIL.Image.Image], np.ndarray] - - -@dataclass -class SanaVideoPipelineOutput(BaseOutput): - r""" - Output class for Sana-Video pipelines. - - Args: - frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): - List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing - denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape - `(batch_size, num_frames, channels, height, width)`. - """ - - frames: torch.Tensor diff --git a/src/diffusers/pipelines/sana_video/__init__.py b/src/diffusers/pipelines/sana_video/__init__.py new file mode 100644 index 000000000000..73e224bf749d --- /dev/null +++ b/src/diffusers/pipelines/sana_video/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"] + _import_structure["pipeline_sana_video_i2v"] = ["SanaImageToVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_sana_video import SanaVideoPipeline + from .pipeline_sana_video_i2v import SanaImageToVideoPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/sana_video/pipeline_output.py b/src/diffusers/pipelines/sana_video/pipeline_output.py new file mode 100644 index 000000000000..4d37923889eb --- /dev/null +++ b/src/diffusers/pipelines/sana_video/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from ...utils import BaseOutput + + +@dataclass +class SanaVideoPipelineOutput(BaseOutput): + r""" + Output class for Sana-Video pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
+ """ + + frames: torch.Tensor diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py similarity index 98% rename from src/diffusers/pipelines/sana/pipeline_sana_video.py rename to src/diffusers/pipelines/sana_video/pipeline_sana_video.py index 5ec498faffb9..a786275e45a9 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py @@ -95,17 +95,16 @@ >>> from diffusers import SanaVideoPipeline >>> from diffusers.utils import export_to_video - >>> model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers" - >>> pipe = SanaVideoPipeline.from_pretrained(model_id) + >>> pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers") >>> pipe.transformer.to(torch.bfloat16) >>> pipe.text_encoder.to(torch.bfloat16) >>> pipe.vae.to(torch.float32) >>> pipe.to("cuda") - >>> model_score = 30 + >>> motion_score = 30 >>> prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional." >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." - >>> motion_prompt = f" motion score: {model_score}." + >>> motion_prompt = f" motion score: {motion_score}." >>> prompt = prompt + motion_prompt >>> output = pipe( @@ -231,6 +230,7 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds def _get_gemma_prompt_embeds( self, prompt: Union[str, List[str]], @@ -827,9 +827,9 @@ def __call__( Examples: Returns: - [`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] is returned, - otherwise a `tuple` is returned where the first element is a list with the generated videos + [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated videos """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): diff --git a/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py b/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py new file mode 100644 index 000000000000..e87880b64cee --- /dev/null +++ b/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py @@ -0,0 +1,1066 @@ +# Copyright 2025 SANA-Video Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SanaVideoPipelineOutput +from .pipeline_sana_video import ASPECT_RATIO_480_BIN, ASPECT_RATIO_720_BIN + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = SanaImageToVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers") + >>> pipe.transformer.to(torch.bfloat16) + >>> pipe.text_encoder.to(torch.bfloat16) + >>> pipe.vae.to(torch.float32) + >>> pipe.to("cuda") + >>> motion_score = 30 + + >>> prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle." + >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." + >>> motion_prompt = f" motion score: {motion_score}." + >>> prompt = prompt + motion_prompt + >>> image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png") + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=480, + ... width=832, + ... frames=81, + ... guidance_scale=6, + ... 
num_inference_steps=50, + ... generator=torch.Generator(device="cuda").manual_seed(42), + ... ).frames[0] + + >>> export_to_video(output, "sana-ti2v-output.mp4", fps=16) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for image/text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). This model + inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all + pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]): + The tokenizer used to tokenize the prompt. + text_encoder ([`Gemma2PreTrainedModel`]): + Text encoder model to encode the input prompts. + vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer ([`SanaVideoTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. 
+ """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + text_encoder: Gemma2PreTrainedModel, + vae: Union[AutoencoderDC, AutoencoderKLWan], + transformer: SanaVideoTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + + self.vae_scale_factor = self.vae_scale_factor_spatial + + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size[1] if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size[0] if getattr(self, "transformer") is not None else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: Optional[List[str]] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. 
+ """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.sana_video.pipeline_sana_video.SanaVideoPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: Optional[List[str]] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_videos_per_prompt (`int`, *optional*, defaults to 1): + number of videos that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. 
+ """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=False, + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the 
scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+                )
+            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+                raise ValueError(
+                    "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+                    f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+                    f" {negative_prompt_attention_mask.shape}."
+                )
+
+    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+    def _text_preprocessing(self, text, clean_caption=False):
+        if clean_caption and not is_bs4_available():
+            logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+            logger.warning("Setting `clean_caption` to False...")
+            clean_caption = False
+
+        if clean_caption and not is_ftfy_available():
+            logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+            logger.warning("Setting `clean_caption` to False...")
+            clean_caption = False
+
+        if not isinstance(text, (tuple, list)):
+            text = [text]
+
+        def process(text: str):
+            if clean_caption:
+                text = self._clean_caption(text)
+                text = self._clean_caption(text)
+            else:
+                text = text.lower().strip()
+            return text
+
+        return [process(t) for t in text]
+
+    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+    def _clean_caption(self, caption):
+        caption = str(caption)
+        caption = ul.unquote_plus(caption)
+        caption = caption.strip().lower()
+        caption = re.sub("<person>", "person", caption)
+        # urls:
+        caption = re.sub(
+            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
+            "",
+            caption,
+        )  # regex for urls
+        caption = re.sub(
+            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
+            "",
+            caption,
+        )  # regex for urls
+        # html:
+        caption = BeautifulSoup(caption, features="html.parser").text
+
+        # @<nickname>
+        caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+        # 31C0—31EF CJK Strokes
+        # 31F0—31FF Katakana Phonetic Extensions
+        # 3200—32FF Enclosed CJK Letters and Months
+        # 3300—33FF CJK Compatibility
+        # 3400—4DBF CJK Unified Ideographs Extension A
+        # 4DC0—4DFF Yijing Hexagram Symbols
+        # 4E00—9FFF CJK Unified Ideographs
+        caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+        caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+        caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+        caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+        caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+        caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+        #######################################################
+
+        # все виды тире / all types of dash --> "-"
+        caption = re.sub(
+            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
+            "-",
+            caption,
+        )
+
+        # кавычки к одному стандарту
+        caption = re.sub(r"[`´«»“”¨]", '"', caption)
+        caption = re.sub(r"[‘’]", "'", caption)
+
+        # &quot;
+        caption = re.sub(r"&quot;?", "", caption)
+        # &amp
+        caption = re.sub(r"&amp", "", caption)
+
+        # ip addresses:
+        caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+        # article ids:
+        caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+        # \n
+        caption = re.sub(r"\\n", " ", caption)
+
+        # "#123"
+        caption = re.sub(r"#\d{1,3}\b", "", caption)
+        # "#12345.."
+        caption = re.sub(r"#\d{5,}\b", "", caption)
+        # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) # [B, C, 1, H, W] + image = image.to(device=device, dtype=self.vae.dtype) + + if isinstance(generator, list): + image_latents = [retrieve_latents(self.vae.encode(image), sample_mode="argmax") for _ in generator] + image_latents = torch.cat(image_latents) + else: + image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") + image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, -1, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to( + image_latents.device, image_latents.dtype + ) + image_latents = (image_latents - latents_mean) * latents_std + + latents[:, :, 0:1] = image_latents.to(dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 6.0, + num_videos_per_prompt: Optional[int] = 1, + height: int = 480, + width: int = 832, + frames: int = 81, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: List[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for video generation. 
 Evaluate the level of detail in the user prompt:",
+            "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, motion, and temporal relationships to create vivid and dynamic scenes.",
+            "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+            "Here are examples of how to transform or refine prompts:",
+            "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat slowly settling into a curled position, peacefully falling asleep on a warm sunny windowsill, with gentle sunlight filtering through surrounding pots of blooming red flowers.",
+            "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps gradually lighting up, a diverse crowd of people in colorful clothing walking past, and a double-decker bus smoothly passing by towering glass skyscrapers.",
+            "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+            "User Prompt: ",
+        ],
+    ) -> Union[SanaVideoPipelineOutput, Tuple]:
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            image (`PipelineImageInput`):
+                The input image to condition the video generation on. The first frame of the generated video will be
+                conditioned on this image.
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the video generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            guidance_scale (`float`, *optional*, defaults to 6.0):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. A higher guidance scale encourages the model to generate videos that are closely
+                linked to the text `prompt`, usually at the expense of lower video quality.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            height (`int`, *optional*, defaults to 480):
+                The height in pixels of the generated video.
+            width (`int`, *optional*, defaults to 832):
+                The width in pixels of the generated video.
+            frames (`int`, *optional*, defaults to 81):
+                The number of frames in the generated video.
+ eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between mp4 or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`SanaVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_480_BIN` or `ASPECT_RATIO_720_BIN`. After the produced latents are decoded into videos, + they are resized back to the requested resolution. Useful for generating non-square videos. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. 
+ complex_human_instruction (`List[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated videos + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 30: + aspect_ratio_bin = ASPECT_RATIO_480_BIN + elif self.transformer.config.sample_size == 22: + aspect_ratio_bin = ASPECT_RATIO_720_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.video_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + image, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latents. 
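+        # The input image is VAE-encoded into the first latent frame inside `prepare_latents`; the
+        # remaining latent frames start from Gaussian noise. The conditioning mask built below marks
+        # that first frame so its timestep is zeroed out and the frame is re-attached unchanged after
+        # each scheduler step.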
+ latent_channels = self.transformer.config.in_channels + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + + latents = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + latent_channels, + height, + width, + frames, + torch.float32, + device, + generator, + latents, + ) + + conditioning_mask = latents.new_zeros( + batch_size, + 1, + latents.shape[2] // self.transformer_temporal_patch_size, + latents.shape[3] // self.transformer_spatial_patch_size, + latents.shape[4] // self.transformer_spatial_patch_size, + ) + conditioning_mask[:, :, 0] = 1.0 + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(conditioning_mask.shape) + timestep = timestep * (1 - conditioning_mask) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + timestep, _ = timestep.chunk(2) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + + noise_pred = noise_pred[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = self.scheduler.step( + noise_pred, t, noise_latents, **extra_step_kwargs, return_dict=False + )[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) 
+ .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + try: + video = self.vae.decode(latents, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + + if use_resolution_binning: + video = self.video_processor.resize_and_crop_tensor(video, orig_width, orig_height) + + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SanaVideoPipelineOutput(frames=video) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 9e32232133f3..9eb123b94e9d 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2147,6 +2147,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class SanaImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class SanaPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/sana_video/__init__.py b/tests/pipelines/sana_video/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/sana/test_sana_video.py b/tests/pipelines/sana_video/test_sana_video.py similarity index 100% rename from tests/pipelines/sana/test_sana_video.py rename to tests/pipelines/sana_video/test_sana_video.py diff --git a/tests/pipelines/sana_video/test_sana_video_i2v.py b/tests/pipelines/sana_video/test_sana_video_i2v.py new file mode 100644 index 000000000000..36a646ca528f --- /dev/null +++ b/tests/pipelines/sana_video/test_sana_video_i2v.py @@ -0,0 +1,238 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
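+
+# The fast tests below build SanaImageToVideoPipeline from small dummy components (a tiny
+# AutoencoderKLWan, a one-layer Gemma2 text encoder and a two-layer SanaVideoTransformer3DModel)
+# so that a full two-step denoising run completes quickly on CPU.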
+ +import gc +import tempfile +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer + +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + SanaImageToVideoPipeline, + SanaVideoTransformer3DModel, +) + +from ...testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + slow, + torch_device, +) +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class SanaImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = SanaImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + text_encoder_config = Gemma2Config( + head_dim=16, + hidden_size=8, + initializer_range=0.02, + intermediate_size=64, + max_position_embeddings=8192, + model_type="gemma2", + num_attention_heads=2, + num_hidden_layers=1, + num_key_value_heads=2, + vocab_size=8, + attn_implementation="eager", + ) + text_encoder = Gemma2Model(text_encoder_config) + tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma") + + torch.manual_seed(0) + transformer = SanaVideoTransformer3DModel( + in_channels=16, + out_channels=16, + num_attention_heads=2, + attention_head_dim=12, + num_layers=2, + num_cross_attention_heads=2, + cross_attention_head_dim=12, + cross_attention_dim=24, + caption_channels=8, + mlp_ratio=2.5, + dropout=0.0, + attention_bias=False, + sample_size=8, + patch_size=(1, 2, 2), + norm_elementwise_affine=False, + norm_eps=1e-6, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + # Create a dummy image input (PIL Image) + image = Image.new("RGB", (32, 32)) + + inputs = { + "image": image, + "prompt": "", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 32, + "width": 32, + "frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + "complex_human_instruction": [], + "use_resolution_binning": False, + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + 
self.assertEqual(generated_video.shape, (9, 3, 32, 32)) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + def test_save_load_local(self, expected_max_difference=5e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + torch.manual_seed(0) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + torch.manual_seed(0) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() + self.assertLess(max_diff, expected_max_difference) + + # TODO(aryan): Create a dummy gemma model with smol vocab size + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_consistent(self): + pass + + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_single_identical(self): + pass + + @unittest.skip("Skipping fp16 test as model is trained with bf16") + def test_float16_inference(self): + # Requires higher tolerance as model seems very sensitive to dtype + super().test_float16_inference(expected_max_diff=0.08) + + @unittest.skip("Skipping fp16 test as model is trained with bf16") + def test_save_load_float16(self): + # Requires higher tolerance as model seems very sensitive to dtype + super().test_save_load_float16(expected_max_diff=0.2) + + +@slow +@require_torch_accelerator +class SanaVideoPipelineIntegrationTests(unittest.TestCase): + prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest." 
+ + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + @unittest.skip("TODO: test needs to be implemented") + def test_sana_video_480p(self): + pass From 3579fdabf905fbcb0952466088573a381f676a55 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 17 Nov 2025 14:23:09 +0530 Subject: [PATCH 26/35] [CI] Make CI logs less verbose (#12674) update --- .github/workflows/nightly_tests.yml | 14 +++++++------- .github/workflows/pr_modular_tests.yml | 2 +- .github/workflows/pr_tests.yml | 8 ++++---- .github/workflows/pr_tests_gpu.yml | 10 +++++----- .github/workflows/push_tests.yml | 10 +++++----- .github/workflows/push_tests_fast.yml | 2 +- .github/workflows/push_tests_mps.yml | 2 +- .github/workflows/release_tests_fast.yml | 12 ++++++------ 8 files changed, 30 insertions(+), 30 deletions(-) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 1738efd63bb7..0f1920aded35 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -84,7 +84,7 @@ jobs: CUBLAS_WORKSPACE_CONFIG: :16:8 run: | pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ + -k "not Flax and not Onnx" \ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \ --report-log=tests_pipeline_${{ matrix.module }}_cuda.log \ tests/pipelines/${{ matrix.module }} @@ -138,7 +138,7 @@ jobs: CUBLAS_WORKSPACE_CONFIG: :16:8 run: | pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ + -k "not Flax and not Onnx" \ --make-reports=tests_torch_${{ matrix.module }}_cuda \ --report-log=tests_torch_${{ matrix.module }}_cuda.log \ tests/${{ matrix.module }} @@ -151,7 +151,7 @@ jobs: CUBLAS_WORKSPACE_CONFIG: :16:8 run: | pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v --make-reports=examples_torch_cuda \ + --make-reports=examples_torch_cuda \ --report-log=examples_torch_cuda.log \ examples/ @@ -198,7 +198,7 @@ jobs: HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} RUN_COMPILE: yes run: | - pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/ + pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/ - name: Failure short reports if: ${{ failure() }} run: cat reports/tests_torch_compile_cuda_failures_short.txt @@ -293,7 +293,7 @@ jobs: CUBLAS_WORKSPACE_CONFIG: :16:8 run: | pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ + -k "not Flax and not Onnx" \ --make-reports=tests_torch_minimum_version_cuda \ tests/models/test_modeling_common.py \ tests/pipelines/test_pipelines_common.py \ @@ -531,7 +531,7 @@ jobs: # HF_HOME: /System/Volumes/Data/mnt/cache # HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} # run: | -# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \ +# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \ # --report-log=tests_torch_mps.log \ # tests/ # - name: Failure short reports @@ -587,7 +587,7 @@ jobs: # HF_HOME: /System/Volumes/Data/mnt/cache # HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} # run: | -# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \ +# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \ # --report-log=tests_torch_mps.log \ # tests/ # - name: Failure short reports diff --git a/.github/workflows/pr_modular_tests.yml 
b/.github/workflows/pr_modular_tests.yml index 7081ee518d55..c32d144220f5 100644 --- a/.github/workflows/pr_modular_tests.yml +++ b/.github/workflows/pr_modular_tests.yml @@ -120,7 +120,7 @@ jobs: if: ${{ matrix.config.framework == 'pytorch_pipelines' }} run: | pytest -n 8 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ + -k "not Flax and not Onnx" \ --make-reports=tests_${{ matrix.config.report }} \ tests/modular_pipelines diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 3306ebe43ef7..5a6648ae4d78 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -126,7 +126,7 @@ jobs: if: ${{ matrix.config.framework == 'pytorch_pipelines' }} run: | pytest -n 8 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ + -k "not Flax and not Onnx" \ --make-reports=tests_${{ matrix.config.report }} \ tests/pipelines @@ -134,7 +134,7 @@ jobs: if: ${{ matrix.config.framework == 'pytorch_models' }} run: | pytest -n 4 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx and not Dependency" \ + -k "not Flax and not Onnx and not Dependency" \ --make-reports=tests_${{ matrix.config.report }} \ tests/models tests/schedulers tests/others @@ -255,11 +255,11 @@ jobs: - name: Run fast PyTorch LoRA tests with PEFT run: | pytest -n 4 --max-worker-restart=0 --dist=loadfile \ - -s -v \ + \ --make-reports=tests_peft_main \ tests/lora/ pytest -n 4 --max-worker-restart=0 --dist=loadfile \ - -s -v \ + \ --make-reports=tests_models_lora_peft_main \ tests/models/ -k "lora" diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml index 6c208ad7cac7..95bbb5a033c0 100644 --- a/.github/workflows/pr_tests_gpu.yml +++ b/.github/workflows/pr_tests_gpu.yml @@ -151,13 +151,13 @@ jobs: run: | if [ "${{ matrix.module }}" = "ip_adapters" ]; then pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ + -k "not Flax and not Onnx" \ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \ tests/pipelines/${{ matrix.module }} else pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }}) pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx and $pattern" \ + -k "not Flax and not Onnx and $pattern" \ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \ tests/pipelines/${{ matrix.module }} fi @@ -222,10 +222,10 @@ jobs: run: | pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }}) if [ -z "$pattern" ]; then - pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \ + pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \ --make-reports=tests_torch_cuda_${{ matrix.module }} else - pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \ + pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \ --make-reports=tests_torch_cuda_${{ matrix.module }} fi @@ -274,7 +274,7 @@ jobs: HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} run: | uv pip install ".[training]" - pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/ + pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/ - name: Failure short reports if: ${{ failure() }} diff --git a/.github/workflows/push_tests.yml 
b/.github/workflows/push_tests.yml index 58133a7f43df..12bcb062d511 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -87,7 +87,7 @@ jobs: CUBLAS_WORKSPACE_CONFIG: :16:8 run: | pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ + -k "not Flax and not Onnx" \ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \ tests/pipelines/${{ matrix.module }} - name: Failure short reports @@ -141,7 +141,7 @@ jobs: CUBLAS_WORKSPACE_CONFIG: :16:8 run: | pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ + -k "not Flax and not Onnx" \ --make-reports=tests_torch_cuda_${{ matrix.module }} \ tests/${{ matrix.module }} @@ -189,7 +189,7 @@ jobs: HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} RUN_COMPILE: yes run: | - pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/ + pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/ - name: Failure short reports if: ${{ failure() }} run: cat reports/tests_torch_compile_cuda_failures_short.txt @@ -230,7 +230,7 @@ jobs: env: HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} run: | - pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/ + pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/ - name: Failure short reports if: ${{ failure() }} run: cat reports/tests_torch_xformers_cuda_failures_short.txt @@ -273,7 +273,7 @@ jobs: HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} run: | uv pip install ".[training]" - pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/ + pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/ - name: Failure short reports if: ${{ failure() }} diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml index ae619d481c48..38cbffaa6315 100644 --- a/.github/workflows/push_tests_fast.yml +++ b/.github/workflows/push_tests_fast.yml @@ -70,7 +70,7 @@ jobs: if: ${{ matrix.config.framework == 'pytorch' }} run: | pytest -n 4 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ + -k "not Flax and not Onnx" \ --make-reports=tests_${{ matrix.config.report }} \ tests/ diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml index 484c7a8eeb49..2d6feb592815 100644 --- a/.github/workflows/push_tests_mps.yml +++ b/.github/workflows/push_tests_mps.yml @@ -57,7 +57,7 @@ jobs: HF_HOME: /System/Volumes/Data/mnt/cache HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | - ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/ + ${CONDA_RUN} python -m pytest -n 0 --make-reports=tests_torch_mps tests/ - name: Failure short reports if: ${{ failure() }} diff --git a/.github/workflows/release_tests_fast.yml b/.github/workflows/release_tests_fast.yml index 808818beada3..efdd6ea2b651 100644 --- a/.github/workflows/release_tests_fast.yml +++ b/.github/workflows/release_tests_fast.yml @@ -84,7 +84,7 @@ jobs: CUBLAS_WORKSPACE_CONFIG: :16:8 run: | pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ + -k "not Flax and not Onnx" \ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \ tests/pipelines/${{ matrix.module }} - name: Failure short reports @@ -137,7 +137,7 @@ jobs: 
CUBLAS_WORKSPACE_CONFIG: :16:8 run: | pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ + -k "not Flax and not Onnx" \ --make-reports=tests_torch_${{ matrix.module }}_cuda \ tests/${{ matrix.module }} @@ -187,7 +187,7 @@ jobs: CUBLAS_WORKSPACE_CONFIG: :16:8 run: | pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ + -k "not Flax and not Onnx" \ --make-reports=tests_torch_minimum_cuda \ tests/models/test_modeling_common.py \ tests/pipelines/test_pipelines_common.py \ @@ -240,7 +240,7 @@ jobs: HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} RUN_COMPILE: yes run: | - pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/ + pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/ - name: Failure short reports if: ${{ failure() }} run: cat reports/tests_torch_compile_cuda_failures_short.txt @@ -281,7 +281,7 @@ jobs: env: HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} run: | - pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/ + pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/ - name: Failure short reports if: ${{ failure() }} run: cat reports/tests_torch_xformers_cuda_failures_short.txt @@ -326,7 +326,7 @@ jobs: HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} run: | uv pip install ".[training]" - pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/ + pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/ - name: Failure short reports if: ${{ failure() }} From 67dc65e2e33fe937deeaa7bc1005cab585393e42 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Mon, 17 Nov 2025 05:09:53 -0800 Subject: [PATCH 27/35] Revert `AutoencoderKLWan`'s `dim_mult` default value back to list (#12640) Revert dim_mult back to list and fix type annotation --- src/diffusers/models/autoencoders/autoencoder_kl_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 5b4b74543ae3..b0b2960aaf18 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -971,7 +971,7 @@ def __init__( base_dim: int = 96, decoder_base_dim: Optional[int] = None, z_dim: int = 16, - dim_mult: Tuple[int, ...] 
= (1, 2, 4, 4), + dim_mult: List[int] = [1, 2, 4, 4], num_res_blocks: int = 2, attn_scales: List[float] = [], temperal_downsample: List[bool] = [False, True, True], From b7df4a5387fd65edf4f16fd5fc8c87a7c815a4c7 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 18 Nov 2025 14:43:06 +0530 Subject: [PATCH 28/35] [CI] Temporarily pin transformers (#12677) * update * update * update * update --- .github/workflows/nightly_tests.yml | 20 +++++++++++++++++--- .github/workflows/pr_modular_tests.yml | 3 ++- .github/workflows/pr_tests.yml | 6 ++++-- .github/workflows/pr_tests_gpu.yml | 9 ++++++--- .github/workflows/push_tests.yml | 9 ++++++--- 5 files changed, 35 insertions(+), 12 deletions(-) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 0f1920aded35..8b7e57e91297 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -73,6 +73,8 @@ jobs: run: | uv pip install -e ".[quality]" uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 uv pip install pytest-reportlog - name: Environment run: | @@ -84,7 +86,7 @@ jobs: CUBLAS_WORKSPACE_CONFIG: :16:8 run: | pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -k "not Flax and not Onnx" \ + -k "not Flax and not Onnx" \ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \ --report-log=tests_pipeline_${{ matrix.module }}_cuda.log \ tests/pipelines/${{ matrix.module }} @@ -126,6 +128,8 @@ jobs: uv pip install -e ".[quality]" uv pip install peft@git+https://github.com/huggingface/peft.git uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 uv pip install pytest-reportlog - name: Environment run: python utils/print_env.py @@ -190,6 +194,8 @@ jobs: - name: Install dependencies run: | uv pip install -e ".[quality,training]" + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 - name: Environment run: | python utils/print_env.py @@ -232,6 +238,8 @@ jobs: uv pip install -e ".[quality]" uv pip install peft@git+https://github.com/huggingface/peft.git uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 uv pip install pytest-reportlog - name: Environment run: | @@ -281,6 +289,8 @@ jobs: uv pip install -e ".[quality]" uv pip install peft@git+https://github.com/huggingface/peft.git uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U 
transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 - name: Environment run: | @@ -358,6 +368,8 @@ jobs: uv pip install ${{ join(matrix.config.additional_deps, ' ') }} fi uv pip install pytest-reportlog + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 - name: Environment run: | python utils/print_env.py @@ -405,6 +417,8 @@ jobs: run: | uv pip install -e ".[quality]" uv pip install -U bitsandbytes optimum_quanto + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 uv pip install pytest-reportlog - name: Environment run: | @@ -531,7 +545,7 @@ jobs: # HF_HOME: /System/Volumes/Data/mnt/cache # HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} # run: | -# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \ +# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \ # --report-log=tests_torch_mps.log \ # tests/ # - name: Failure short reports @@ -587,7 +601,7 @@ jobs: # HF_HOME: /System/Volumes/Data/mnt/cache # HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} # run: | -# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \ +# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \ # --report-log=tests_torch_mps.log \ # tests/ # - name: Failure short reports diff --git a/.github/workflows/pr_modular_tests.yml b/.github/workflows/pr_modular_tests.yml index c32d144220f5..13c228621f5c 100644 --- a/.github/workflows/pr_modular_tests.yml +++ b/.github/workflows/pr_modular_tests.yml @@ -109,7 +109,8 @@ jobs: - name: Install dependencies run: | uv pip install -e ".[quality]" - uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps - name: Environment diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 5a6648ae4d78..674e62ff443a 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -115,7 +115,8 @@ jobs: - name: Install dependencies run: | uv pip install -e ".[quality]" - uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps - name: Environment @@ -246,7 +247,8 @@ jobs: uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps uv pip install -U tokenizers uv pip uninstall accelerate && uv pip install -U 
accelerate@git+https://github.com/huggingface/accelerate.git --no-deps - uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 - name: Environment run: | diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml index 95bbb5a033c0..369c7a607737 100644 --- a/.github/workflows/pr_tests_gpu.yml +++ b/.github/workflows/pr_tests_gpu.yml @@ -131,7 +131,8 @@ jobs: run: | uv pip install -e ".[quality]" uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git - uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 - name: Environment run: | @@ -201,7 +202,8 @@ jobs: uv pip install -e ".[quality]" uv pip install peft@git+https://github.com/huggingface/peft.git uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git - uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 - name: Environment run: | @@ -262,7 +264,8 @@ jobs: nvidia-smi - name: Install dependencies run: | - uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 uv pip install -e ".[quality,training]" - name: Environment diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 12bcb062d511..6bf2516d5880 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -76,7 +76,8 @@ jobs: run: | uv pip install -e ".[quality]" uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git - uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 - name: Environment run: | python utils/print_env.py @@ -128,7 +129,8 @@ jobs: uv pip install -e ".[quality]" uv pip install peft@git+https://github.com/huggingface/peft.git uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git - uv pip uninstall transformers 
huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 - name: Environment run: | @@ -180,7 +182,8 @@ jobs: - name: Install dependencies run: | uv pip install -e ".[quality,training]" - uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 - name: Environment run: | python utils/print_env.py From ab71f3c864902fd80110e48b8affee8a00989119 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 19 Nov 2025 08:19:00 +0530 Subject: [PATCH 29/35] [core] Refactor hub attn kernels (#12475) * refactor how attention kernels from hub are used. * up * refactor according to Dhruv's ideas. Co-authored-by: Dhruv Nair * empty Co-authored-by: Dhruv Nair * empty Co-authored-by: Dhruv Nair * empty Co-authored-by: dn6 * up --------- Co-authored-by: Dhruv Nair Co-authored-by: Dhruv Nair --- src/diffusers/models/attention_dispatch.py | 67 +++++++++++++++------- src/diffusers/models/modeling_utils.py | 8 ++- src/diffusers/utils/constants.py | 1 - src/diffusers/utils/kernels_utils.py | 23 -------- tests/others/test_attention_backends.py | 1 - 5 files changed, 54 insertions(+), 46 deletions(-) delete mode 100644 src/diffusers/utils/kernels_utils.py diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 92a4a6a59936..8504504981a3 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -16,6 +16,7 @@ import functools import inspect import math +from dataclasses import dataclass from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -42,7 +43,7 @@ is_xformers_available, is_xformers_version, ) -from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS +from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS if TYPE_CHECKING: @@ -82,24 +83,11 @@ flash_attn_3_func = None flash_attn_3_varlen_func = None - if _CAN_USE_AITER_ATTN: from aiter import flash_attn_func as aiter_flash_attn_func else: aiter_flash_attn_func = None -if DIFFUSERS_ENABLE_HUB_KERNELS: - if not is_kernels_available(): - raise ImportError( - "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`." 
- ) - from ..utils.kernels_utils import _get_fa3_from_hub - - flash_attn_interface_hub = _get_fa3_from_hub() - flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func -else: - flash_attn_3_func_hub = None - if _CAN_USE_SAGE_ATTN: from sageattention import ( sageattn, @@ -261,6 +249,25 @@ def _is_context_parallel_available( return supports_context_parallel +@dataclass +class _HubKernelConfig: + """Configuration for downloading and using a hub-based attention kernel.""" + + repo_id: str + function_attr: str + revision: Optional[str] = None + kernel_fn: Optional[Callable] = None + + +# Registry for hub-based attention kernels +_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = { + # TODO: temporary revision for now. Remove when merged upstream into `main`. + AttentionBackendName._FLASH_3_HUB: _HubKernelConfig( + repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs" + ) +} + + @contextlib.contextmanager def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE): """ @@ -415,13 +422,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None # TODO: add support Hub variant of FA3 varlen later elif backend in [AttentionBackendName._FLASH_3_HUB]: - if not DIFFUSERS_ENABLE_HUB_KERNELS: - raise RuntimeError( - f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`." - ) if not is_kernels_available(): raise RuntimeError( - f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`." + f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`." ) elif backend == AttentionBackendName.AITER: @@ -571,6 +574,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): return q_idx >= kv_idx +# ===== Helpers for downloading kernels ===== +def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: + if backend not in _HUB_KERNELS_REGISTRY: + return + config = _HUB_KERNELS_REGISTRY[backend] + + if config.kernel_fn is not None: + return + + try: + from kernels import get_kernel + + kernel_module = get_kernel(config.repo_id, revision=config.revision) + kernel_func = getattr(kernel_module, config.function_attr) + + # Cache the downloaded kernel function in the config object + config.kernel_fn = kernel_func + + except Exception as e: + logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}") + raise + + # ===== torch op registrations ===== # Registrations are required for fullgraph tracing compatibility # TODO: this is only required because the beta release FA3 does not have it. 
There is a PR adding @@ -1418,7 +1444,8 @@ def _flash_attention_3_hub( return_attn_probs: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - out = flash_attn_3_func_hub( + func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn + out = func( q=query, k=key, v=value, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index e4a8f30e721f..f06822c741ca 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -595,7 +595,11 @@ def set_attention_backend(self, backend: str) -> None: attention as backend. """ from .attention import AttentionModuleMixin - from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements + from .attention_dispatch import ( + AttentionBackendName, + _check_attention_backend_requirements, + _maybe_download_kernel_for_backend, + ) # TODO: the following will not be required when everything is refactored to AttentionModuleMixin from .attention_processor import Attention, MochiAttention @@ -606,8 +610,10 @@ def set_attention_backend(self, backend: str) -> None: available_backends = {x.value for x in AttentionBackendName.__members__.values()} if backend not in available_backends: raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) + backend = AttentionBackendName(backend) _check_attention_backend_requirements(backend) + _maybe_download_kernel_for_backend(backend) attention_classes = (Attention, MochiAttention, AttentionModuleMixin) for module in self.modules(): diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index a18f28606b3e..c46fa4363483 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -46,7 +46,6 @@ DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES -DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py deleted file mode 100644 index 26d6e3972fb7..000000000000 --- a/src/diffusers/utils/kernels_utils.py +++ /dev/null @@ -1,23 +0,0 @@ -from ..utils import get_logger -from .import_utils import is_kernels_available - - -logger = get_logger(__name__) - - -_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3" - - -def _get_fa3_from_hub(): - if not is_kernels_available(): - return None - else: - from kernels import get_kernel - - try: - # TODO: temporary revision for now. Remove when merged upstream into `main`. 
- flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs") - return flash_attn_3_hub - except Exception as e: - logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}") - raise diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py index 2e5a2fc82bb6..273e3f9c0721 100644 --- a/tests/others/test_attention_backends.py +++ b/tests/others/test_attention_backends.py @@ -7,7 +7,6 @@ ```bash export RUN_ATTENTION_BACKEND_TESTS=yes -export DIFFUSERS_ENABLE_HUB_KERNELS=yes pytest tests/others/test_attention_backends.py ``` From 6d8973ffe28f14176a350e819929e469fc85cb91 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 19 Nov 2025 09:30:04 +0530 Subject: [PATCH 30/35] [CI] Fix indentation issue in workflow files (#12685) update --- .github/workflows/pr_tests_gpu.yml | 16 ++++++++-------- .github/workflows/push_tests.yml | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml index 369c7a607737..468979d379c1 100644 --- a/.github/workflows/pr_tests_gpu.yml +++ b/.github/workflows/pr_tests_gpu.yml @@ -1,4 +1,4 @@ -name: Fast GPU Tests on PR +name: Fast GPU Tests on PR on: pull_request: @@ -71,7 +71,7 @@ jobs: if: ${{ failure() }} run: | echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY - + setup_torch_cuda_pipeline_matrix: needs: [check_code_quality, check_repository_consistency] name: Setup Torch Pipelines CUDA Slow Tests Matrix @@ -132,7 +132,7 @@ jobs: uv pip install -e ".[quality]" uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git - uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 - name: Environment run: | @@ -150,18 +150,18 @@ jobs: # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms CUBLAS_WORKSPACE_CONFIG: :16:8 run: | - if [ "${{ matrix.module }}" = "ip_adapters" ]; then + if [ "${{ matrix.module }}" = "ip_adapters" ]; then pytest -n 1 --max-worker-restart=0 --dist=loadfile \ -k "not Flax and not Onnx" \ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \ tests/pipelines/${{ matrix.module }} - else + else pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }}) pytest -n 1 --max-worker-restart=0 --dist=loadfile \ -k "not Flax and not Onnx and $pattern" \ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \ tests/pipelines/${{ matrix.module }} - fi + fi - name: Failure short reports if: ${{ failure() }} @@ -225,10 +225,10 @@ jobs: pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }}) if [ -z "$pattern" ]; then pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \ - --make-reports=tests_torch_cuda_${{ matrix.module }} + --make-reports=tests_torch_cuda_${{ matrix.module }} else pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \ - --make-reports=tests_torch_cuda_${{ matrix.module }} + --make-reports=tests_torch_cuda_${{ matrix.module }} fi - name: Failure short 
reports diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 6bf2516d5880..7b1c441d3dc0 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -77,7 +77,7 @@ jobs: uv pip install -e ".[quality]" uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git - uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 + uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 - name: Environment run: | python utils/print_env.py From 7d005fc19566a5e7570d62fffd2ee41fbfbe564e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 19 Nov 2025 10:23:03 +0300 Subject: [PATCH 31/35] up --- src/diffusers/models/transformers/transformer_wan_s2v.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index ea042ce8b778..8440f4346968 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -840,8 +840,6 @@ class WanS2VTransformer3DModel( Epsilon value for normalization layers. add_img_emb (`bool`, defaults to `False`): Whether to use img_emb. - image_dim (`int`, *optional*, defaults to `None`): - The dimension of image embeddings. Set to `None` for S2V model as it doesn't use image conditioning. added_kv_proj_dim (`int`, *optional*, defaults to `None`): The number of channels to use for the added key and value projections. If `None`, no projection is used. 
zero_timestep (`bool`, defaults to `True`): @@ -876,7 +874,6 @@ def __init__( cross_attn_norm: bool = True, qk_norm: Optional[str] = "rms_norm_across_heads", eps: float = 1e-6, - image_dim: Optional[int] = None, added_kv_proj_dim: Optional[int] = None, rope_max_seq_len: int = 1024, enable_framepack: bool = True, From a96b145304c36d47228cd24f2200acc4d30ce604 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 19 Nov 2025 21:19:24 +0530 Subject: [PATCH 32/35] [CI] Fix failing Pipeline CPU tests (#12681) update Co-authored-by: Sayak Paul --- tests/pipelines/audioldm2/test_audioldm2.py | 56 +++++++++---------- .../kandinsky2_2/test_kandinsky_combined.py | 3 +- .../kandinsky2_2/test_kandinsky_inpaint.py | 2 + .../test_stable_diffusion_latent_upscale.py | 2 + 4 files changed, 33 insertions(+), 30 deletions(-) diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py index 14ff1272a29e..5ccba1dabbfe 100644 --- a/tests/pipelines/audioldm2/test_audioldm2.py +++ b/tests/pipelines/audioldm2/test_audioldm2.py @@ -21,11 +21,9 @@ import pytest import torch from transformers import ( - ClapAudioConfig, ClapConfig, ClapFeatureExtractor, ClapModel, - ClapTextConfig, GPT2Config, GPT2LMHeadModel, RobertaTokenizer, @@ -111,33 +109,33 @@ def get_dummy_components(self): latent_channels=4, ) torch.manual_seed(0) - text_branch_config = ClapTextConfig( - bos_token_id=0, - eos_token_id=2, - hidden_size=8, - intermediate_size=37, - layer_norm_eps=1e-05, - num_attention_heads=1, - num_hidden_layers=1, - pad_token_id=1, - vocab_size=1000, - projection_dim=8, - ) - audio_branch_config = ClapAudioConfig( - spec_size=8, - window_size=4, - num_mel_bins=8, - intermediate_size=37, - layer_norm_eps=1e-05, - depths=[1, 1], - num_attention_heads=[1, 1], - num_hidden_layers=1, - hidden_size=192, - projection_dim=8, - patch_size=2, - patch_stride=2, - patch_embed_input_channels=4, - ) + text_branch_config = { + "bos_token_id": 0, + "eos_token_id": 2, + "hidden_size": 8, + "intermediate_size": 37, + "layer_norm_eps": 1e-05, + "num_attention_heads": 1, + "num_hidden_layers": 1, + "pad_token_id": 1, + "vocab_size": 1000, + "projection_dim": 8, + } + audio_branch_config = { + "spec_size": 8, + "window_size": 4, + "num_mel_bins": 8, + "intermediate_size": 37, + "layer_norm_eps": 1e-05, + "depths": [1, 1], + "num_attention_heads": [1, 1], + "num_hidden_layers": 1, + "hidden_size": 192, + "projection_dim": 8, + "patch_size": 2, + "patch_stride": 2, + "patch_embed_input_channels": 4, + } text_encoder_config = ClapConfig( text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=16 ) diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py index 476fc584cc56..62f5853da9a5 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py @@ -23,7 +23,7 @@ KandinskyV22InpaintCombinedPipeline, ) -from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device +from ...testing_utils import enable_full_determinism, require_accelerator, require_torch_accelerator, torch_device from ..test_pipelines_common import PipelineTesterMixin from .test_kandinsky import Dummies from .test_kandinsky_img2img import Dummies as Img2ImgDummies @@ -402,6 +402,7 @@ def test_save_load_local(self): def test_save_load_optional_components(self): super().test_save_load_optional_components(expected_max_difference=5e-4) + 
@require_accelerator def test_sequential_cpu_offload_forward_pass(self): super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4) diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py index d4eb650263af..8a693e9c2dd0 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py @@ -37,6 +37,7 @@ load_image, load_numpy, numpy_cosine_similarity_distance, + require_accelerator, require_torch_accelerator, slow, torch_device, @@ -254,6 +255,7 @@ def test_model_cpu_offload_forward_pass(self): def test_save_load_optional_components(self): super().test_save_load_optional_components(expected_max_difference=5e-4) + @require_accelerator def test_sequential_cpu_offload_forward_pass(self): super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py index 2e4b428dfeb5..285c2fea7ebc 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py @@ -37,6 +37,7 @@ floats_tensor, load_image, load_numpy, + require_accelerator, require_torch_accelerator, slow, torch_device, @@ -222,6 +223,7 @@ def test_stable_diffusion_latent_upscaler_multiple_init_images(self): def test_attention_slicing_forward_pass(self): super().test_attention_slicing_forward_pass(expected_max_diff=7e-3) + @require_accelerator def test_sequential_cpu_offload_forward_pass(self): super().test_sequential_cpu_offload_forward_pass(expected_max_diff=3e-3) From 15370f84121008efdd99d41ff7dfeac0bfaeeb04 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Wed, 19 Nov 2025 18:36:41 +0100 Subject: [PATCH 33/35] Improve docstrings and type hints in scheduling_pndm.py (#12676) * Enhance docstrings and type hints in PNDMScheduler class - Updated parameter descriptions to include default values and specific types using Literal for better clarity. - Improved docstring formatting and consistency across methods, including detailed explanations for the `_get_prev_sample` method. - Added type hints for method return types to enhance code readability and maintainability. * Refactor docstring in PNDMScheduler class to enhance clarity - Simplified the explanation of the method for computing the previous sample from the current sample. - Updated the reference to the PNDM paper for better accessibility. - Removed redundant notation explanations to streamline the documentation. --- src/diffusers/schedulers/scheduling_pndm.py | 72 +++++++++++---------- 1 file changed, 38 insertions(+), 34 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index aded6c224671..651532b06ddb 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -79,15 +79,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): methods the library implements for all schedulers such as loading and saving. Args: - num_train_timesteps (`int`, defaults to 1000): + num_train_timesteps (`int`, defaults to `1000`): The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): + beta_start (`float`, defaults to `0.0001`): The starting `beta` value of inference. 
- beta_end (`float`, defaults to 0.02): + beta_end (`float`, defaults to `0.02`): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): - The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. skip_prk_steps (`bool`, defaults to `False`): @@ -97,14 +96,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): Each diffusion step uses the alphas product value at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the alpha value at step 0. - prediction_type (`str`, defaults to `epsilon`, *optional*): + prediction_type (`"epsilon"` or `"v_prediction"`, defaults to `"epsilon"`): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process) - or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) - paper). - timestep_spacing (`str`, defaults to `"leading"`): + or `v_prediction` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper). + timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, defaults to 0): + steps_offset (`int`, defaults to `0`): An offset added to the inference steps, as required by some model families. """ @@ -117,12 +115,12 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, skip_prk_steps: bool = False, set_alpha_to_one: bool = False, - prediction_type: str = "epsilon", - timestep_spacing: str = "leading", + prediction_type: Literal["epsilon", "v_prediction"] = "epsilon", + timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading", steps_offset: int = 0, ): if trained_betas is not None: @@ -164,7 +162,7 @@ def __init__( self.plms_timesteps = None self.timesteps = None - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -243,7 +241,7 @@ def step( The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - return_dict (`bool`): + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Returns: @@ -276,14 +274,13 @@ def step_prk( The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. 
- return_dict (`bool`): + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. - """ if self.num_inference_steps is None: raise ValueError( @@ -335,14 +332,13 @@ def step_plms( The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - return_dict (`bool`): + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. - """ if self.num_inference_steps is None: raise ValueError( @@ -403,19 +399,27 @@ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tens """ return sample - def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): - # See formula (9) of PNDM paper https://huggingface.co/papers/2202.09778 - # this function computes x_(t−δ) using the formula of (9) - # Note that x_t needs to be added to both sides of the equation - - # Notation ( -> - # alpha_prod_t -> α_t - # alpha_prod_t_prev -> α_(t−δ) - # beta_prod_t -> (1 - α_t) - # beta_prod_t_prev -> (1 - α_(t−δ)) - # sample -> x_t - # model_output -> e_θ(x_t, t) - # prev_sample -> x_(t−δ) + def _get_prev_sample( + self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor + ) -> torch.Tensor: + """ + Compute the previous sample x_(t-δ) from the current sample x_t using formula (9) from the [PNDM + paper](https://huggingface.co/papers/2202.09778). + + Args: + sample (`torch.Tensor`): + The current sample x_t. + timestep (`int`): + The current timestep t. + prev_timestep (`int`): + The previous timestep (t-δ). + model_output (`torch.Tensor`): + The model output e_θ(x_t, t). + + Returns: + `torch.Tensor`: + The previous sample x_(t-δ). 
+ """ alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t @@ -489,5 +493,5 @@ def add_noise( noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps From d5da453de56fe73e0cfd26204ccca441af568ca1 Mon Sep 17 00:00:00 2001 From: Pratim Dasude Date: Thu, 20 Nov 2025 00:48:46 +0530 Subject: [PATCH 34/35] Community Pipeline: FluxFillControlNetInpaintPipeline for FLUX Fill-Based Inpainting with ControlNet (#12649) * new flux fill controlnet inpaint pipline * Delete src/diffusers/pipelines/flux/pipline_flux_fill_controlnet_Inpaint.py deleting from main flux pipeline * Fluc_fill_controlnet community pipline * Update README.md * Apply style fixes --- examples/community/README.md | 105 +- .../pipline_flux_fill_controlnet_Inpaint.py | 1319 +++++++++++++++++ 2 files changed, 1423 insertions(+), 1 deletion(-) create mode 100644 examples/community/pipline_flux_fill_controlnet_Inpaint.py diff --git a/examples/community/README.md b/examples/community/README.md index 69e9c7576103..4ff9c4d77704 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -88,7 +88,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar | FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. 
| [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) | | Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)| | Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - | [Net-Mist](https://github.com/Net-Mist) | - +| Flux Fill ControlNet Pipeline | A modified version of the `FluxFillPipeline` and `FluxControlNetInpaintPipeline` that supports ControlNet with the Flux Fill model.| [Flux Fill ControlNet Pipeline](#flux-fill-controlnet-pipeline) | - | [pratim4dasude](https://github.com/pratim4dasude) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. @@ -5527,3 +5527,106 @@ images = pipe( ).images images[0].save("pizzeria.png") ``` + +# Flux Fill ControlNet Pipeline + +This implementation of Flux Fill + ControlNet Inpaint combines the fill-style masked editing of FLUX.1-Fill-dev with full ControlNet conditioning. The base image is processed through the Fill model while the ControlNet receives the corresponding conditioning input (depth, canny, pose, etc.), and both outputs are fused during denoising to guide structure and composition.
+ +## Example Usage + + +```python +import torch +from diffusers import ( + FluxControlNetModel, + FluxPriorReduxPipeline, +) +from diffusers.utils import load_image + +# NEW PIPELINE (updated name) +from pipline_flux_fill_controlnet_Inpaint import FluxControlNetFillInpaintPipeline + +device = "cuda" if torch.cuda.is_available() else "cpu" +dtype = torch.bfloat16 + +# Models +base_model = "black-forest-labs/FLUX.1-Fill-dev" +controlnet_model = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0" +prior_model = "black-forest-labs/FLUX.1-Redux-dev" + +# Load ControlNet +controlnet = FluxControlNetModel.from_pretrained( + controlnet_model, + torch_dtype=dtype, +) + +# Load Fill + ControlNet Pipeline +fill_pipe = FluxControlNetFillInpaintPipeline.from_pretrained( + base_model, + controlnet=controlnet, + torch_dtype=dtype, +).to(device) + +# OPTIONAL FP8 +# fill_pipe.transformer.enable_layerwise_casting( +# storage_dtype=torch.float8_e4m3fn, +# compute_dtype=torch.bfloat16 +# ) + +# OPTIONAL Prior Redux +#pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained( +# prior_model, +# torch_dtype=dtype, +#).to(device) + +# Inputs + +# combined_image = load_image("person_input.png") + + +# 1. Prior conditioning +#prior_out = pipe_prior_redux( +# image=cloth_image, +# prompt=cloth_prompt, +#) + +# 2. Fill Inpaint with ControlNet + +# canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). + +img = load_image(r"imgs/background.jpg") +mask = load_image(r"imgs/mask.png") + +control_image_depth = load_image(r"imgs/dog_depth _2.png") + +result = fill_pipe( + prompt="a dog on a bench", + image=img, + mask_image=mask, + + control_image=control_image_depth, + control_mode=[2], # union mode + control_guidance_start=0.0, + control_guidance_end=0.8, + controlnet_conditioning_scale=0.9, + + height=1024, + width=1024, + + strength=1.0, + guidance_scale=50.0, + num_inference_steps=60, + max_sequence_length=512, + +# **prior_out, +) + +# result.images[0].save("flux_fill_controlnet_inpaint.png") + +from datetime import datetime +timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") +result.images[0].save(f"flux_fill_controlnet_inpaint_depth{timestamp}.jpg") +``` + diff --git a/examples/community/pipline_flux_fill_controlnet_Inpaint.py b/examples/community/pipline_flux_fill_controlnet_Inpaint.py new file mode 100644 index 000000000000..6b1c204df03b --- /dev/null +++ b/examples/community/pipline_flux_fill_controlnet_Inpaint.py @@ -0,0 +1,1319 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from diffusers.models.transformers import FluxTransformer2DModel +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor + + 
+if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxControlNetInpaintPipeline + >>> from diffusers.models import FluxControlNetModel + >>> from diffusers.utils import load_image + + >>> controlnet = FluxControlNetModel.from_pretrained( + ... "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.float16 + ... ) + >>> pipe = FluxControlNetInpaintPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> control_image = load_image( + ... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" + ... ) + >>> init_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + ... ) + >>> mask_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + ... ) + + >>> prompt = "A girl holding a sign that says InstantX" + >>> image = pipe( + ... prompt, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... control_guidance_start=0.2, + ... control_guidance_end=0.8, + ... controlnet_conditioning_scale=0.7, + ... strength=0.7, + ... num_inference_steps=28, + ... guidance_scale=3.5, + ... ).images[0] + >>> image.save("flux_controlnet_inpaint.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def retrieve_latents_fill( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the 
scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlNetFillInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux controlnet pipeline for inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
+ text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "control_image", "mask", "masked_image_latents"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + controlnet: Union[ + FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel + ], + ): + super().__init__() + if isinstance(controlnet, (list, tuple)): + controlnet = FluxMultiControlNetModel(controlnet) + + self.register_modules( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + controlnet=controlnet, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, 
self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + 
output_type, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." 
+ ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
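+ # e.g. with the default Flux vae_scale_factor of 8, a 1024 px side becomes 2 * (1024 // 16) = 128 latent rows,
+ # which the 2x2 packing below turns into 64 patch rows (illustrative numbers, assuming the standard Flux VAE config)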
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." 
+ " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_mask_latents_fill( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # 1. calculate the height and width of the latents + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + # 2. encode the masked image + if masked_image.shape[1] == num_channels_latents: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents_fill(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + # 3. duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + batch_size = batch_size * num_images_per_prompt + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # 4. pack the masked_image_latents + # batch_size, num_channels_latents, height, width -> batch_size, height//2 * width//2 , num_channels_latents*4 + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + + # 5.resize mask to latents shape we we concatenate the mask to the latents + mask = mask[:, 0, :, :] # batch_size, 8 * height, 8 * width (mask has not been 8x compressed) + mask = mask.view( + batch_size, height, self.vae_scale_factor, width, self.vae_scale_factor + ) # batch_size, height, 8, width, 8 + mask = mask.permute(0, 2, 4, 1, 3) # batch_size, 8, 8, height, width + mask = mask.reshape( + batch_size, self.vae_scale_factor * self.vae_scale_factor, height, width + ) # batch_size, 8*8, height, width + + # 6. pack the mask: + # batch_size, 64, height, width -> batch_size, height//2 * width//2 , 64*2*2 + mask = self._pack_latents( + mask, + batch_size, + self.vae_scale_factor * self.vae_scale_factor, + height, + width, + ) + mask = mask.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + padding_mask_crop: Optional[int] = None, + sigmas: Optional[List[float]] = None, + num_inference_steps: int = 28, + guidance_scale: float = 7.0, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image(s) to inpaint. + mask_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The mask image(s) to use for inpainting. White pixels in the mask will be repainted, while black pixels + will be preserved. 
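+ If fewer masks than the effective batch size are passed, the masks are duplicated to match it; the
+ requested batch size must be divisible by the number of masks.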
+ masked_image_latents (`torch.FloatTensor`, *optional*): + Pre-generated masked image latents. + control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The ControlNet input condition. Image to control the generation. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 0.6): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. + padding_mask_crop (`int`, *optional*): + The size of the padding to use when cropping the mask. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_mode (`int` or `List[int]`, *optional*): + The mode for the ControlNet. If multiple ControlNets are used, this should be a list. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original transformer. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to + make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + Additional keyword arguments to be passed to the joint attention mechanism. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising step during the inference. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. 
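+ Tensors returned by the callback under these keys replace the pipeline's working copies for the
+ remaining steps (for example `latents` or `prompt_embeds`).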
+ max_sequence_length (`int`, *optional*, defaults to 512): + The maximum length of the sequence to be generated. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + global_height = height + global_width = width + + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type=output_type, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + # 3. Encode input prompt + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region( + mask_image, global_width, global_height, pad=padding_mask_crop + ) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 5. 
Prepare control image + # num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_latents = self.vae.config.latent_channels + + if isinstance(self.controlnet, FluxControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True + if self.controlnet.input_hint_block is None: + # vae encode + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + # set control mode + if control_mode is not None: + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + elif isinstance(self.controlnet, FluxMultiControlNetModel): + control_images = [] + + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True + for i, control_image_ in enumerate(control_image): + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image_.shape[-2:] + + if self.controlnet.nets[0].input_hint_block is None: + # vae encode + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + control_images.append(control_image_) + + control_image = control_images + + # set control mode + control_mode_ = [] + if isinstance(control_mode, list): + for cmode in control_mode: + if cmode is None: + control_mode_.append(-1) + else: + control_mode_.append(cmode) + control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + # 6. 
Prepare timesteps + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * ( + int(global_width) // self.vae_scale_factor // 2 + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 7. Prepare latent variables + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 8. Prepare mask latents + mask_condition = self.mask_processor.preprocess( + mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + + mask_imagee = self.mask_processor.preprocess(mask_image, height=height, width=width) + masked_imagee = init_image * (1 - mask_imagee) + masked_imagee = masked_imagee.to(dtype=self.vae.dtype, device=device) + maskkk, masked_image_latentsss = self.prepare_mask_latents_fill( + mask_imagee, + masked_imagee, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) + + # 9. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # predict the noise residual + if isinstance(self.controlnet, FluxMultiControlNetModel): + use_guidance = self.controlnet.nets[0].config.guidance_embeds + else: + use_guidance = self.controlnet.config.guidance_embeds + if use_guidance: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=cond_scale, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) + + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + masked_image_latents_fill = torch.cat((masked_image_latentsss, maskkk), dim=-1) + latent_model_input = torch.cat([latents, masked_image_latents_fill], dim=2) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # For inpainting, we need to apply the mask and add the masked image latents + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + control_image = callback_outputs.pop("control_image", control_image) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # Post-processing + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, global_height, global_width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) From 6f1042e36cd588a7b66498f45c3bb7085e4fa395 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Fri, 21 Nov 2025 19:18:09 +0100 Subject: [PATCH 35/35] Improve docstrings and type hints in scheduling_lms_discrete.py (#12678) * Enhance type hints and docstrings in LMSDiscreteScheduler class Updated type hints for function parameters and return types to improve code clarity and maintainability. Enhanced docstrings for several methods, providing clearer descriptions of their functionality and expected arguments. Notable changes include specifying Literal types for certain parameters and ensuring consistent return type annotations across the class. * docs: Add specific paper reference to `_convert_to_karras` docstring. * Refactor `_convert_to_karras` docstring in DPMSolverSDEScheduler to include detailed descriptions and a specific paper reference, enhancing clarity and documentation consistency. --- .../schedulers/scheduling_dpmsolver_sde.py | 15 +++- .../schedulers/scheduling_lms_discrete.py | 84 +++++++++++++------ 2 files changed, 71 insertions(+), 28 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index e22954d4e6ea..ef89feb1cad6 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -488,9 +488,20 @@ def _sigma_to_t(self, sigma, log_sigmas): t = t.reshape(sigma.shape) return t - # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + # Copied from diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. 
+ + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. + """ sigma_min: float = in_sigmas[-1].item() sigma_max: float = in_sigmas[0].item() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 573678b100ba..d0766eed1b66 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -99,15 +99,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): methods the library implements for all schedulers such as loading and saving. Args: - num_train_timesteps (`int`, defaults to 1000): + num_train_timesteps (`int`, defaults to `1000`): The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): + beta_start (`float`, defaults to `0.0001`): The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): + beta_end (`float`, defaults to `0.02`): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): - The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. + beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. use_karras_sigmas (`bool`, *optional*, defaults to `False`): @@ -118,14 +117,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): use_beta_sigmas (`bool`, *optional*, defaults to `False`): Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. - prediction_type (`str`, defaults to `epsilon`, *optional*): + prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). - timestep_spacing (`str`, defaults to `"linspace"`): + timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, defaults to 0): + steps_offset (`int`, defaults to `0`): An offset added to the inference steps, as required by some model families. 
""" @@ -138,13 +137,13 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False, - prediction_type: str = "epsilon", - timestep_spacing: str = "linspace", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", + timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace", steps_offset: int = 0, ): if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: @@ -183,7 +182,15 @@ def __init__( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property - def init_noise_sigma(self): + def init_noise_sigma(self) -> Union[float, torch.Tensor]: + """ + The standard deviation of the initial noise distribution. + + Returns: + `float` or `torch.Tensor`: + The standard deviation of the initial noise distribution, computed based on the maximum sigma value and + the timestep spacing configuration. + """ # standard deviation of the initial noise distribution if self.config.timestep_spacing in ["linspace", "trailing"]: return self.sigmas.max() @@ -191,21 +198,29 @@ def init_noise_sigma(self): return (self.sigmas.max() ** 2 + 1) ** 0.5 @property - def step_index(self): + def step_index(self) -> Optional[int]: """ - The index counter for current timestep. It will increase 1 after each scheduler step. + The index counter for current timestep. It will increase by 1 after each scheduler step. + + Returns: + `int` or `None`: + The current step index, or `None` if not initialized. """ return self._step_index @property - def begin_index(self): + def begin_index(self) -> Optional[int]: """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + + Returns: + `int` or `None`: + The begin index for the scheduler, or `None` if not set. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. @@ -239,14 +254,21 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T self.is_scale_input_called = True return sample - def get_lms_coefficient(self, order, t, current_order): + def get_lms_coefficient(self, order: int, t: int, current_order: int) -> float: """ Compute the linear multistep coefficient. Args: - order (): - t (): - current_order (): + order (`int`): + The order of the linear multistep method. + t (`int`): + The current timestep index. + current_order (`int`): + The current order for which to compute the coefficient. + + Returns: + `float`: + The computed linear multistep coefficient. """ def lms_derivative(tau): @@ -261,7 +283,7 @@ def lms_derivative(tau): return integrated_coeff - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). 
@@ -367,7 +389,7 @@ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: self._step_index = self._begin_index # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): + def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: """ Convert sigma values to corresponding timestep values through interpolation. @@ -403,9 +425,19 @@ def _sigma_to_t(self, sigma, log_sigmas): t = t.reshape(sigma.shape) return t - # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`torch.Tensor`): + The input sigma values to be converted. + + Returns: + `torch.Tensor`: + The converted sigma values following the Karras noise schedule. + """ sigma_min: float = in_sigmas[-1].item() sigma_max: float = in_sigmas[0].item() @@ -629,5 +661,5 @@ def add_noise( noisy_samples = original_samples + noise * sigma return noisy_samples - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps
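For reference, a minimal standalone sketch of the Karras sigma schedule that the updated `_convert_to_karras` docstrings describe ([Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364)). The `rho = 7.0` default mirrors the schedulers' implementation; the `sigma_min`/`sigma_max` values below are illustrative only and not taken from any particular checkpoint:

import torch

def karras_sigmas(sigma_min: float, sigma_max: float, num_steps: int, rho: float = 7.0) -> torch.Tensor:
    # Interpolate linearly in sigma**(1/rho) space, then raise back to the rho-th power,
    # which concentrates sampling steps near sigma_min.
    ramp = torch.linspace(0, 1, num_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

sigmas = karras_sigmas(sigma_min=0.0292, sigma_max=14.6146, num_steps=10)
print(sigmas)  # decreases monotonically from sigma_max to sigma_min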