From a40d776b41082d1d0864f0473c9b1c8d218bc067 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 29 Aug 2025 13:22:54 +0300 Subject: [PATCH 001/131] temp --- .../pipelines/wan/pipeline_wan_s2v.py | 824 ++++++++++++++++++ 1 file changed, 824 insertions(+) create mode 100644 src/diffusers/pipelines/wan/pipeline_wan_s2v.py diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py new file mode 100644 index 000000000000..b7fd0b05980f --- /dev/null +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -0,0 +1,824 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL +import regex as re +import torch +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import WanPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> import numpy as np + >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + >>> from transformers import CLIPVisionModel + + >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers + >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" + >>> image_encoder = CLIPVisionModel.from_pretrained( + ... model_id, subfolder="image_encoder", torch_dtype=torch.float32 + ... ) + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = WanImageToVideoPipeline.from_pretrained( + ... model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... 
) + >>> max_area = 480 * 832 + >>> aspect_ratio = image.height / image.width + >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + >>> image = image.resize((width, height)) + >>> prompt = ( + ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in " + ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + ... ) + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=height, + ... width=width, + ... num_frames=81, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + image_encoder ([`CLIPVisionModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically + the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. + transformer ([`WanTransformer3DModel`]): + Conditional Transformer to denoise the input latents. 
+ scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer_2 ([`WanTransformer3DModel`], *optional*): + Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising, + `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only + `transformer` is used. + boundary_ratio (`float`, *optional*, defaults to `None`): + Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + image_processor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModel = None, + transformer: WanTransformer3DModel = None, + transformer_2: WanTransformer3DModel = None, + boundary_ratio: Optional[float] = None, + expand_timesteps: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + transformer=transformer, + scheduler=scheduler, + image_processor=image_processor, + transformer_2=transformer_2, + ) + self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.image_processor = image_processor + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in 
prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_image( + self, + image: PipelineImageInput, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+
+            negative_prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=negative_prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
+        return prompt_embeds, negative_prompt_embeds
+
+    def check_inputs(
+        self,
+        prompt,
+        negative_prompt,
+        image,
+        height,
+        width,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        image_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        guidance_scale_2=None,
+    ):
+        if image is not None and image_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        if image is None and image_embeds is None:
+            raise ValueError(
+                "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+            )
+        if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+            raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+        if height % 16 != 0 or width % 16 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif negative_prompt is not None and (
+            not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+        ):
+            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
+        if self.config.boundary_ratio is not None and image_embeds is not None:
+            raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not None.")
+
+    def prepare_latents(
+        self,
+        image: PipelineImageInput,
+        batch_size: int,
+        num_channels_latents: int = 16,
+        height: int = 480,
+        width: int = 832,
+        num_frames: int = 81,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        last_image: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+        latent_height = height // self.vae_scale_factor_spatial
+        latent_width = width // self.vae_scale_factor_spatial
+
+        shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] + + if self.config.expand_timesteps: + video_condition = image + + elif last_image is None: + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + else: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], + dim=2, + ) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + if self.config.expand_timesteps: + first_frame_mask = torch.ones( + 1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device + ) + first_frame_mask[:, :, 0] = 0 + return latents, latent_condition, first_frame_mask + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + + if last_image is None: + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + else: + mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + guidance_scale_2: Optional[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: 
Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + last_image: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `480`): + The height of the generated video. + width (`int`, defaults to `832`): + The width of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. 
Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Encode image embedding + transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # only wan 2.1 i2v transformer accepts image_embeds + if self.transformer is not None and self.transformer.config.image_dim is not None: + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + + latents_outputs = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + last_image, + ) + if self.config.expand_timesteps: + # wan 2.2 5b i2v use firt_frame_mask to mask timesteps + latents, condition, first_frame_mask = latents_outputs + else: + latents, condition = latents_outputs + + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.config.boundary_ratio is not None: + boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + # low-noise stage in wan2.2 + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + if self.config.expand_timesteps: + latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents + latent_model_input = latent_model_input.to(transformer_dtype) + + # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size) + temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten() + # batch_size, seq_len + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + with current_model.cache_context("cond"): + noise_pred = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with current_model.cache_context("uncond"): + noise_uncond = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if self.config.expand_timesteps: + latents = (1 - first_frame_mask) * condition + first_frame_mask * latents + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # 
Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) From 4be705f274c6b030dcf1d77080823eca40e5a266 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 29 Aug 2025 14:27:29 +0300 Subject: [PATCH 002/131] template2 --- .../transformers/transformer_wan_s2v.py | 362 ++++++++++++++++++ .../pipelines/wan/pipeline_wan_s2v.py | 27 +- 2 files changed, 370 insertions(+), 19 deletions(-) create mode 100644 src/diffusers/models/transformers/transformer_wan_s2v.py diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py new file mode 100644 index 000000000000..13f0c152c1aa --- /dev/null +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -0,0 +1,362 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import FeedForward +from ..cache_utils import CacheMixin +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm +from .transformer_wan import ( + WanAttention, + WanAttnProcessor, + WanRotaryPosEmbed, + WanTimeTextImageEmbedding, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class WanS2VTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: Optional[int] = None, + apply_input_projection: bool = False, + apply_output_projection: bool = False, + ): + super().__init__() + + # 1. Input projection + self.proj_in = None + if apply_input_projection: + self.proj_in = nn.Linear(dim, dim) + + # 2. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + processor=WanAttnProcessor(), + ) + + # 3. Cross-attention + self.attn2 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + processor=WanAttnProcessor(), + ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 4. Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + # 5. 
Output projection
+        self.proj_out = None
+        if apply_output_projection:
+            self.proj_out = nn.Linear(dim, dim)
+
+        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        control_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        rotary_emb: torch.Tensor,
+    ) -> torch.Tensor:
+        if self.proj_in is not None:
+            control_hidden_states = self.proj_in(control_hidden_states)
+            control_hidden_states = control_hidden_states + hidden_states
+
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+            self.scale_shift_table + temb.float()
+        ).chunk(6, dim=1)
+
+        # 1. Self-attention
+        norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
+            control_hidden_states
+        )
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
+        control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)
+
+        # 2. Cross-attention
+        norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
+        control_hidden_states = control_hidden_states + attn_output
+
+        # 3. Feed-forward
+        norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+            control_hidden_states
+        )
+        ff_output = self.ffn(norm_hidden_states)
+        control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(
+            control_hidden_states
+        )
+
+        conditioning_states = None
+        if self.proj_out is not None:
+            conditioning_states = self.proj_out(control_hidden_states)
+
+        return conditioning_states, control_hidden_states
+
+
+class WanS2VTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+    r"""
+    A Transformer model for video-like data used in the Wan model.
+
+    Args:
+        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+        num_attention_heads (`int`, defaults to `40`):
+            The number of attention heads to use.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        text_dim (`int`, defaults to `4096`):
+            Input dimension for text embeddings.
+        freq_dim (`int`, defaults to `256`):
+            Dimension for sinusoidal time embeddings.
+        ffn_dim (`int`, defaults to `13824`):
+            Intermediate dimension in feed-forward network.
+        num_layers (`int`, defaults to `40`):
+            The number of layers of transformer blocks to use.
+        window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+            Window size for local attention (-1 indicates global attention).
+        cross_attn_norm (`bool`, defaults to `True`):
+            Enable cross-attention normalization.
+        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
+            Enable query/key normalization.
+        eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        add_img_emb (`bool`, defaults to `False`):
+            Whether to use img_emb.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] + _no_split_modules = ["WanS2VTransformerBlock"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + + @register_to_config + def __init__( + self, + patch_size: Tuple[int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: Optional[str] = "rms_norm_across_heads", + eps: float = 1e-6, + image_dim: Optional[int] = None, + added_kv_proj_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: Optional[int] = None, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + + # 1. Patch & position embedding + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. Condition embeddings + # image_embedding_dim=1280 for I2V model + self.condition_embedder = WanTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList( + [ + WanS2VTransformerBlock( + inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim + ) + for _ in range(num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + control_hidden_states: torch.Tensor = None, + control_hidden_states_scale: torch.Tensor = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + if control_hidden_states_scale is None: + control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers)) + control_hidden_states_scale = torch.unbind(control_hidden_states_scale) + if len(control_hidden_states_scale) != len(self.config.vace_layers): + raise ValueError( + f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be " + f"equal to {len(self.config.vace_layers)}." + ) + + # 1. Rotary position embedding + rotary_emb = self.rope(hidden_states) + + # 2. Patch embedding + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + control_hidden_states = self.vace_patch_embedding(control_hidden_states) + control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2) + control_hidden_states_padding = control_hidden_states.new_zeros( + batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2) + ) + control_hidden_states = torch.cat([control_hidden_states, control_hidden_states_padding], dim=1) + + # 3. Time embedding + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image + ) + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + # 4. Image embedding + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 5. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + # Prepare VACE hints + control_hidden_states_list = [] + for i, block in enumerate(self.vace_blocks): + conditioning_states, control_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb + ) + control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i])) + control_hidden_states_list = control_hidden_states_list[::-1] + + for i, block in enumerate(self.blocks): + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + if i in self.config.vace_layers: + control_hint, scale = control_hidden_states_list.pop() + hidden_states = hidden_states + control_hint * scale + else: + # Prepare VACE hints + control_hidden_states_list = [] + for i, block in enumerate(self.vace_blocks): + conditioning_states, control_hidden_states = block( + hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb + ) + control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i])) + control_hidden_states_list = control_hidden_states_list[::-1] + + for i, block in enumerate(self.blocks): + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + if i in self.config.vace_layers: + control_hint, scale = control_hidden_states_list.pop() + hidden_states = hidden_states + control_hint * scale + + # 6. Output norm, projection & unpatchify + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. 
+ # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index b7fd0b05980f..68e2baf81a10 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -23,7 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput from ...loaders import WanLoraLoaderMixin -from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...models import AutoencoderKLWan, WanS2VTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor @@ -124,9 +124,9 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): +class WanSpeechToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): r""" - Pipeline for image-to-video generation using Wan. + Pipeline for image-and-sound-to-video generation using Wan. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). @@ -149,20 +149,11 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. - transformer_2 ([`WanTransformer3DModel`], *optional*): - Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising, - `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only - `transformer` is used. - boundary_ratio (`float`, *optional*, defaults to `None`): - Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. - The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, - `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < - boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. 
""" - model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->audio_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - _optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"] + _optional_components = ["transformer", "image_encoder", "image_processor"] def __init__( self, @@ -172,9 +163,7 @@ def __init__( scheduler: FlowMatchEulerDiscreteScheduler, image_processor: CLIPImageProcessor = None, image_encoder: CLIPVisionModel = None, - transformer: WanTransformer3DModel = None, - transformer_2: WanTransformer3DModel = None, - boundary_ratio: Optional[float] = None, + transformer: WanS2VTransformer3DModel = None, expand_timesteps: bool = False, ): super().__init__() @@ -187,9 +176,9 @@ def __init__( transformer=transformer, scheduler=scheduler, image_processor=image_processor, - transformer_2=transformer_2, + audio_encoder=audio_encoder ) - self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps) + self.register_to_config(expand_timesteps=expand_timesteps) self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 From cd1824599c95420f363d6b18b230c6aba79a1d24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 29 Aug 2025 15:28:03 +0300 Subject: [PATCH 003/131] up --- scripts/convert_wan_to_diffusers.py | 37 ++++++++++++++++++- src/diffusers/__init__.py | 4 ++ src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_wan_s2v.py | 3 +- src/diffusers/pipelines/__init__.py | 16 +++++++- src/diffusers/pipelines/wan/__init__.py | 2 + .../pipelines/wan/pipeline_wan_s2v.py | 1 - 8 files changed, 59 insertions(+), 7 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 599c90be57ce..ab237a2a3b45 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -13,6 +13,8 @@ UniPCMultistepScheduler, WanImageToVideoPipeline, WanPipeline, + WanS2VTransformer3DModel, + WanSpeechToVideoPipeline, WanTransformer3DModel, WanVACEPipeline, WanVACETransformer3DModel, @@ -341,6 +343,27 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: } RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP + elif model_type == "Wan2.2-S2V-14B": + config = { + "model_id": "Wan-AI/Wan2.2-S2V-14B", + "diffusers_config": { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 14336, + "freq_dim": 256, + "in_channels": 48, + "num_attention_heads": 24, + "num_layers": 30, + "out_channels": 48, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + }, + } + RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT + SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP return config, RENAME_DICT, SPECIAL_KEYS_REMAP @@ -357,7 +380,9 @@ def convert_transformer(model_type: str, stage: str = None): original_state_dict = load_sharded_safetensors(model_dir) with init_empty_weights(): - if "VACE" not in model_type: + if "S2V" in model_type: + transformer = WanS2VTransformer3DModel.from_config(diffusers_config) + elif "VACE" not in model_type: transformer = 
WanTransformer3DModel.from_config(diffusers_config) else: transformer = WanVACETransformer3DModel.from_config(diffusers_config) @@ -903,7 +928,7 @@ def get_args(): if __name__ == "__main__": args = get_args() - if "Wan2.2" in args.model_type and "TI2V" not in args.model_type: + if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "S2V" not in args.model_type: transformer = convert_transformer(args.model_type, stage="high_noise_model") transformer_2 = convert_transformer(args.model_type, stage="low_noise_model") else: @@ -983,6 +1008,14 @@ def get_args(): vae=vae, scheduler=scheduler, ) + elif "S2V" in args.model_type: + pipe = WanSpeechToVideoPipeline( + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + ) else: pipe = WanPipeline( transformer=transformer, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a606941f1d7a..edafa4f6099e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -243,6 +243,7 @@ "UNetSpatioTemporalConditionModel", "UVit2DModel", "VQModel", + "WanS2VTransformer3DModel", "WanTransformer3DModel", "WanVACETransformer3DModel", "attention_backend", @@ -593,6 +594,7 @@ "VQDiffusionPipeline", "WanImageToVideoPipeline", "WanPipeline", + "WanSpeechToVideoPipeline", "WanVACEPipeline", "WanVideoToVideoPipeline", "WuerstchenCombinedPipeline", @@ -912,6 +914,7 @@ UNetSpatioTemporalConditionModel, UVit2DModel, VQModel, + WanS2VTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, attention_backend, @@ -1232,6 +1235,7 @@ VQDiffusionPipeline, WanImageToVideoPipeline, WanPipeline, + WanSpeechToVideoPipeline, WanVACEPipeline, WanVideoToVideoPipeline, WuerstchenCombinedPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 49ac2a1c56fd..19aba5031267 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -99,6 +99,7 @@ _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] + _import_structure["transformers.transformer_wan_s2v"] = ["WanS2VTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] @@ -196,6 +197,7 @@ T5FilmDecoder, Transformer2DModel, TransformerTemporalModel, + WanS2VTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, ) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index b60f0636e6dc..ce7fbce40b87 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -36,4 +36,5 @@ from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel from .transformer_temporal import TransformerTemporalModel from .transformer_wan import WanTransformer3DModel + from .transformer_wan_s2v import WanS2VTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 13f0c152c1aa..f2bdf9f265cf 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -13,7 +13,7 @@ # limitations 
under the License. import math -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -201,7 +201,6 @@ def __init__( inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels - # 1. Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b3cfc6228736..43267e88d6b5 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -380,7 +380,13 @@ "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", ] - _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"] + _import_structure["wan"] = [ + "WanPipeline", + "WanImageToVideoPipeline", + "WanVideoToVideoPipeline", + "WanVACEPipeline", + "WanSpeechToVideoPipeline", + ] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", @@ -778,7 +784,13 @@ UniDiffuserTextDecoder, ) from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline - from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline + from .wan import ( + WanImageToVideoPipeline, + WanPipeline, + WanSpeechToVideoPipeline, + WanVACEPipeline, + WanVideoToVideoPipeline, + ) from .wuerstchen import ( WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, diff --git a/src/diffusers/pipelines/wan/__init__.py b/src/diffusers/pipelines/wan/__init__.py index bb96372b1db2..f21a66dbb7e6 100644 --- a/src/diffusers/pipelines/wan/__init__.py +++ b/src/diffusers/pipelines/wan/__init__.py @@ -24,6 +24,7 @@ else: _import_structure["pipeline_wan"] = ["WanPipeline"] _import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"] + _import_structure["pipeline_wan_s2v"] = ["WanSpeechToVideoPipeline"] _import_structure["pipeline_wan_vace"] = ["WanVACEPipeline"] _import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -36,6 +37,7 @@ else: from .pipeline_wan import WanPipeline from .pipeline_wan_i2v import WanImageToVideoPipeline + from .pipeline_wan_s2v import WanSpeechToVideoPipeline from .pipeline_wan_vace import WanVACEPipeline from .pipeline_wan_video2video import WanVideoToVideoPipeline diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 68e2baf81a10..69f371f440ef 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -176,7 +176,6 @@ def __init__( transformer=transformer, scheduler=scheduler, image_processor=image_processor, - audio_encoder=audio_encoder ) self.register_to_config(expand_timesteps=expand_timesteps) From bbe282f68edc13aafc54563206c1e7a457a11cde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 29 Aug 2025 15:28:30 +0300 Subject: [PATCH 004/131] fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index bbb971249604..6a259dddee03 100644 --- 
a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1443,6 +1443,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class WanS2VTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class WanTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 22dfc5fccae1..c7345153545a 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -3242,6 +3242,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class WanSpeechToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class WanVACEPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 41fba83f329967e7c95014574e33ff1b2a1c6e36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 29 Aug 2025 15:56:39 +0300 Subject: [PATCH 005/131] upp --- scripts/convert_wan_to_diffusers.py | 3 + src/diffusers/pipelines/wan/__init__.py | 2 + src/diffusers/pipelines/wan/audio_encoder.py | 169 ++++++++++++++++++ .../pipelines/wan/pipeline_wan_s2v.py | 3 + 4 files changed, 177 insertions(+) create mode 100644 src/diffusers/pipelines/wan/audio_encoder.py diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index ab237a2a3b45..1ad7d731e0f6 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -19,6 +19,7 @@ WanVACEPipeline, WanVACETransformer3DModel, ) +from diffusers.pipelines.wan.audio_encoder import WanAudioEncoder TRANSFORMER_KEYS_RENAME_DICT = { @@ -1009,12 +1010,14 @@ def get_args(): scheduler=scheduler, ) elif "S2V" in args.model_type: + audio_encoder = WanAudioEncoder.from_pretrained("facebook/wav2vec2-base-960h") pipe = WanSpeechToVideoPipeline( transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, scheduler=scheduler, + audio_encoder=audio_encoder, ) else: pipe = WanPipeline( diff --git a/src/diffusers/pipelines/wan/__init__.py b/src/diffusers/pipelines/wan/__init__.py index f21a66dbb7e6..adb716943359 100644 --- a/src/diffusers/pipelines/wan/__init__.py +++ b/src/diffusers/pipelines/wan/__init__.py @@ -22,6 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: + _import_structure["audio_encoder"] = ["WanAudioEncoder"] _import_structure["pipeline_wan"] = ["WanPipeline"] _import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"] _import_structure["pipeline_wan_s2v"] = ["WanSpeechToVideoPipeline"] @@ -35,6 +36,7 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: + from .audio_encoder import WanAudioEncoder from .pipeline_wan import 
WanPipeline from .pipeline_wan_i2v import WanImageToVideoPipeline from .pipeline_wan_s2v import WanSpeechToVideoPipeline diff --git a/src/diffusers/pipelines/wan/audio_encoder.py b/src/diffusers/pipelines/wan/audio_encoder.py new file mode 100644 index 000000000000..3150f450d9c7 --- /dev/null +++ b/src/diffusers/pipelines/wan/audio_encoder.py @@ -0,0 +1,169 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import librosa +import numpy as np +import torch +import torch.nn.functional as F +from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor + + +def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None): + required_duration = num_sample / target_fps + required_origin_frames = int(np.ceil(required_duration * original_fps)) + if required_duration > total_frames / original_fps: + raise ValueError("required_duration must be less than video length") + + if fixed_start is not None and fixed_start >= 0: + start_frame = fixed_start + else: + max_start = total_frames - required_origin_frames + if max_start < 0: + raise ValueError("video length is too short") + start_frame = np.random.randint(0, max_start + 1) + start_time = start_frame / original_fps + + end_time = start_time + required_duration + time_points = np.linspace(start_time, end_time, num_sample, endpoint=False) + + frame_indices = np.round(np.array(time_points) * original_fps).astype(int) + frame_indices = np.clip(frame_indices, 0, total_frames - 1) + return frame_indices + + +def linear_interpolation(features, input_fps, output_fps, output_len=None): + """ + features: shape=[1, T, 512] input_fps: fps for audio, f_a output_fps: fps for video, f_m output_len: video length + """ + features = features.transpose(1, 2) # [1, 512, T] + seq_len = features.shape[2] / float(input_fps) # T/f_a + if output_len is None: + output_len = int(seq_len * output_fps) # f_m*T/f_a + output_features = F.interpolate( + features, size=output_len, align_corners=True, mode="linear" + ) # [1, 512, output_len] + return output_features.transpose(1, 2) # [1, output_len, 512] + + +class WanAudioEncoder: + def __init__(self, device="cpu", model_id="facebook/wav2vec2-base-960h"): + # load pretrained model + self.processor = Wav2Vec2Processor.from_pretrained(model_id) + self.model = Wav2Vec2ForCTC.from_pretrained(model_id) + + self.model = self.model.to(device) + + self.video_rate = 30 + + def extract_audio_feat(self, audio_path, return_all_layers=False, dtype=torch.float32): + audio_input, sample_rate = librosa.load(audio_path, sr=16000) + + input_values = self.processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values + + # INFERENCE + + # retrieve logits & take argmax + res = self.model(input_values.to(self.model.device), output_hidden_states=True) + if return_all_layers: + feat = torch.cat(res.hidden_states) + else: + feat = res.hidden_states[-1] + feat = linear_interpolation(feat, input_fps=50, 
output_fps=self.video_rate) + + z = feat.to(dtype) # Encoding for the motion + return z + + def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1 + + bucket_num = min_batch_num * batch_frames + batch_idx = [stride * i for i in range(bucket_num)] + batch_audio_eb = [] + for bi in batch_idx: + if bi < audio_frame_num: + audio_sample_stride = 2 + chosen_idx = list( + range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride) + ) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = ( + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) + if not return_all_layers + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + ) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num + + def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + scale = self.video_rate / fps + + min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1 + + bucket_num = min_batch_num * batch_frames + padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num + batch_idx = get_sample_indices( + original_fps=self.video_rate, + total_frames=audio_frame_num + padd_audio_num, + target_fps=fps, + num_sample=bucket_num, + fixed_start=0, + ) + batch_audio_eb = [] + audio_sample_stride = int(self.video_rate / fps) + for bi in batch_idx: + if bi < audio_frame_num: + chosen_idx = list( + range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride) + ) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = ( + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) + if not return_all_layers + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + ) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 69f371f440ef..df0c78f72b38 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -29,6 +29,7 @@ from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline +from .audio_encoder import WanAudioEncoder from .pipeline_output import WanPipelineOutput @@ -164,6 +165,7 @@ def __init__( 
image_processor: CLIPImageProcessor = None, image_encoder: CLIPVisionModel = None, transformer: WanS2VTransformer3DModel = None, + audio_encoder: WanAudioEncoder = None, expand_timesteps: bool = False, ): super().__init__() @@ -176,6 +178,7 @@ def __init__( transformer=transformer, scheduler=scheduler, image_processor=image_processor, + audio_encoder=audio_encoder, ) self.register_to_config(expand_timesteps=expand_timesteps) From 1a0059f5ab21a193ff60a5a16becfd0da5008839 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 29 Aug 2025 18:00:03 +0300 Subject: [PATCH 006/131] Refactor WanSpeechToVideoPipeline: remove unused image encoder and update example imports Add unit tests for WanSpeechToVideoPipeline and WanS2VTransformer3DModel and gguf --- .../pipelines/wan/pipeline_wan_s2v.py | 81 +------ .../pipelines/wan/test_wan_speech_to_video.py | 216 ++++++++++++++++++ tests/quantization/gguf/test_gguf.py | 28 +++ 3 files changed, 256 insertions(+), 69 deletions(-) create mode 100644 tests/pipelines/wan/test_wan_speech_to_video.py diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index df0c78f72b38..3268de4895a0 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -18,7 +18,7 @@ import PIL import regex as re import torch -from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel +from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput @@ -50,19 +50,13 @@ ```python >>> import torch >>> import numpy as np - >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline + >>> from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline >>> from diffusers.utils import export_to_video, load_image - >>> from transformers import CLIPVisionModel - >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers - >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" - >>> image_encoder = CLIPVisionModel.from_pretrained( - ... model_id, subfolder="image_encoder", torch_dtype=torch.float32 - ... ) + >>> # Available models: Wan-AI/Wan2.2-S2V-14B-Diffusers + >>> model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) - >>> pipe = WanImageToVideoPipeline.from_pretrained( - ... model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 - ... ) + >>> pipe = WanSpeechToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> image = load_image( @@ -139,11 +133,6 @@ class WanSpeechToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): text_encoder ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - image_encoder ([`CLIPVisionModel`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically - the - [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) - variant. transformer ([`WanTransformer3DModel`]): Conditional Transformer to denoise the input latents. 
scheduler ([`UniPCMultistepScheduler`]): @@ -152,7 +141,7 @@ class WanSpeechToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. """ - model_cpu_offload_seq = "text_encoder->image_encoder->audio_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->audio_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] _optional_components = ["transformer", "image_encoder", "image_processor"] @@ -162,8 +151,6 @@ def __init__( text_encoder: UMT5EncoderModel, vae: AutoencoderKLWan, scheduler: FlowMatchEulerDiscreteScheduler, - image_processor: CLIPImageProcessor = None, - image_encoder: CLIPVisionModel = None, transformer: WanS2VTransformer3DModel = None, audio_encoder: WanAudioEncoder = None, expand_timesteps: bool = False, @@ -174,10 +161,8 @@ def __init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, - image_encoder=image_encoder, transformer=transformer, scheduler=scheduler, - image_processor=image_processor, audio_encoder=audio_encoder, ) self.register_to_config(expand_timesteps=expand_timesteps) @@ -185,7 +170,6 @@ def __init__( self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - self.image_processor = image_processor def _get_t5_prompt_embeds( self, @@ -375,12 +359,6 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - if self.config.boundary_ratio is None and guidance_scale_2 is not None: - raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") - - if self.config.boundary_ratio is not None and image_embeds is not None: - raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.") - def prepare_latents( self, image: PipelineImageInput, @@ -552,10 +530,6 @@ def __call__( of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - guidance_scale_2 (`float`, *optional*, defaults to `None`): - Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's - `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` - and the pipeline's `boundary_ratio` are not None. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -628,9 +602,6 @@ def __call__( num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) - if self.config.boundary_ratio is not None and guidance_scale_2 is None: - guidance_scale_2 = guidance_scale - self._guidance_scale = guidance_scale self._guidance_scale_2 = guidance_scale_2 self._attention_kwargs = attention_kwargs @@ -660,21 +631,11 @@ def __call__( ) # Encode image embedding - transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype + transformer_dtype = self.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - # only wan 2.1 i2v transformer accepts image_embeds - if self.transformer is not None and self.transformer.config.image_dim is not None: - if image_embeds is None: - if last_image is None: - image_embeds = self.encode_image(image, device) - else: - image_embeds = self.encode_image([image, last_image], device) - image_embeds = image_embeds.repeat(batch_size, 1, 1) - image_embeds = image_embeds.to(transformer_dtype) - # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps @@ -682,10 +643,6 @@ def __call__( # 5. Prepare latent variables num_channels_latents = self.vae.config.z_dim image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) - if last_image is not None: - last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( - device, dtype=torch.float32 - ) latents_outputs = self.prepare_latents( image, @@ -710,11 +667,6 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - if self.config.boundary_ratio is not None: - boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps - else: - boundary_timestep = None - with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -722,15 +674,6 @@ def __call__( self._current_timestep = t - if boundary_timestep is None or t >= boundary_timestep: - # wan2.1 or high-noise stage in wan2.2 - current_model = self.transformer - current_guidance_scale = guidance_scale - else: - # low-noise stage in wan2.2 - current_model = self.transformer_2 - current_guidance_scale = guidance_scale_2 - if self.config.expand_timesteps: latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents latent_model_input = latent_model_input.to(transformer_dtype) @@ -743,8 +686,8 @@ def __call__( latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) timestep = t.expand(latents.shape[0]) - with current_model.cache_context("cond"): - noise_pred = current_model( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, @@ -754,8 +697,8 @@ def __call__( )[0] if self.do_classifier_free_guidance: - with current_model.cache_context("uncond"): - noise_uncond = current_model( + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, @@ -763,7 +706,7 @@ def 
__call__( attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] diff --git a/tests/pipelines/wan/test_wan_speech_to_video.py b/tests/pipelines/wan/test_wan_speech_to_video.py new file mode 100644 index 000000000000..3bbe673d2de4 --- /dev/null +++ b/tests/pipelines/wan/test_wan_speech_to_video.py @@ -0,0 +1,216 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + WanS2VTransformer3DModel, + WanSpeechToVideoPipeline, +) + +from ...testing_utils import enable_full_determinism +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class WanSpeechToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WanSpeechToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = WanS2VTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=3, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + num_frames = 17 + height = 16 + width = 16 + + video = [Image.new("RGB", (height, width))] * num_frames + 
mask = [Image.new("L", (height, width), 0)] * num_frames + + inputs = { + "video": video, + "mask": mask, + "prompt": "dance monkey", + "negative_prompt": "negative", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": num_frames, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames[0] + self.assertEqual(video.shape, (17, 3, 16, 16)) + + # fmt: off + expected_slice = [0.4523, 0.45198, 0.44872, 0.45326, 0.45211, 0.45258, 0.45344, 0.453, 0.52431, 0.52572, 0.50701, 0.5118, 0.53717, 0.53093, 0.50557, 0.51402] + # fmt: on + + video_slice = video.flatten() + video_slice = torch.cat([video_slice[:8], video_slice[-8:]]) + video_slice = [round(x, 5) for x in video_slice.tolist()] + self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3)) + + def test_inference_with_single_reference_image(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["reference_images"] = Image.new("RGB", (16, 16)) + video = pipe(**inputs).frames[0] + self.assertEqual(video.shape, (17, 3, 16, 16)) + + # fmt: off + expected_slice = [0.45247, 0.45214, 0.44874, 0.45314, 0.45171, 0.45299, 0.45428, 0.45317, 0.51378, 0.52658, 0.53361, 0.52303, 0.46204, 0.50435, 0.52555, 0.51342] + # fmt: on + + video_slice = video.flatten() + video_slice = torch.cat([video_slice[:8], video_slice[-8:]]) + video_slice = [round(x, 5) for x in video_slice.tolist()] + self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3)) + + def test_inference_with_multiple_reference_image(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["reference_images"] = [[Image.new("RGB", (16, 16))] * 2] + video = pipe(**inputs).frames[0] + self.assertEqual(video.shape, (17, 3, 16, 16)) + + # fmt: off + expected_slice = [0.45321, 0.45221, 0.44818, 0.45375, 0.45268, 0.4519, 0.45271, 0.45253, 0.51244, 0.52223, 0.51253, 0.51321, 0.50743, 0.51177, 0.51626, 0.50983] + # fmt: on + + video_slice = video.flatten() + video_slice = torch.cat([video_slice[:8], video_slice[-8:]]) + video_slice = [round(x, 5) for x in video_slice.tolist()] + self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3)) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip("Errors out because passing multiple prompts at once is not yet supported by this pipeline.") + def test_encode_prompt_works_in_isolation(self): + pass + + @unittest.skip("Batching is not yet supported with this pipeline") + def test_inference_batch_consistent(self): + pass + + @unittest.skip("Batching is not yet supported with this pipeline") + def test_inference_batch_single_identical(self): + return super().test_inference_batch_single_identical() + + @unittest.skip( + "AutoencoderKLWan encoded latents are always in FP32. 
This test is not designed to handle mixed dtype inputs" + ) + def test_float16_inference(self): + pass + + @unittest.skip( + "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs" + ) + def test_save_load_float16(self): + pass diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 38322459e761..442d236438fb 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -16,6 +16,7 @@ HiDreamImageTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline, + WanS2VTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, ) @@ -721,6 +722,33 @@ def get_dummy_inputs(self): } +class WanS2VGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/QuantStack/Wan2.2-S2V-14B-GGUF/blob/main/Wan2.2-S2V-14B-Q3_K_S.gguf" + torch_dtype = torch.bfloat16 + model_cls = WanS2VTransformer3DModel + expected_memory_use_in_gb = 9 + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states": torch.randn( + (1, 96, 2, 64, 64), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states_scale": torch.randn( + (8,), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + + @require_torch_version_greater("2.7.1") class GGUFCompileTests(QuantCompileTests, unittest.TestCase): torch_dtype = torch.bfloat16 From 44f4866d1f707b6bf709fc267d879131add94a56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 29 Aug 2025 19:43:43 +0300 Subject: [PATCH 007/131] encoding image to audio --- src/diffusers/audio_processor.py | 76 ++++++++++++ .../pipelines/wan/pipeline_wan_s2v.py | 109 ++++++++---------- 2 files changed, 125 insertions(+), 60 deletions(-) create mode 100644 src/diffusers/audio_processor.py diff --git a/src/diffusers/audio_processor.py b/src/diffusers/audio_processor.py new file mode 100644 index 000000000000..31a1b09674be --- /dev/null +++ b/src/diffusers/audio_processor.py @@ -0,0 +1,76 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +import numpy as np +import PIL.Image +import torch + + +PipelineAudioInput = Union[ + PIL.Image.Image, # librosa? + np.ndarray, + torch.Tensor, + List[PIL.Image.Image], # librosa? + List[np.ndarray], + List[torch.Tensor], +] + + +def is_valid_image(image) -> bool: + r""" + Checks if the input is a valid image. + + A valid image can be: + - A `PIL.Image.Image`. 
+ - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. + + Returns: + `bool`: + `True` if the input is a valid image, `False` otherwise. + """ + return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) + + +def is_valid_image_imagelist(images): + r""" + Checks if the input is a valid image or list of images. + + The input can be one of the following formats: + - A 4D tensor or numpy array (batch of images). + - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or + `torch.Tensor`. + - A list of valid images. + + Args: + images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`): + The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid + images. + + Returns: + `bool`: + `True` if the input is valid, `False` otherwise. + """ + if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: + return True + elif is_valid_image(images): + return True + elif isinstance(images, list): + return all(is_valid_image(image) for image in images) + return False diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 3268de4895a0..e17d103350e5 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -20,6 +20,7 @@ import torch from transformers import AutoTokenizer, UMT5EncoderModel +from ...audio_processor import PipelineAudioInput from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput from ...loaders import WanLoraLoaderMixin @@ -51,7 +52,7 @@ >>> import torch >>> import numpy as np >>> from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline - >>> from diffusers.utils import export_to_video, load_image + >>> from diffusers.utils import export_to_video, load_image, load_audio >>> # Available models: Wan-AI/Wan2.2-S2V-14B-Diffusers >>> model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" @@ -73,10 +74,12 @@ ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." ... ) >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + >>> audio = load_audio(...) >>> output = pipe( - ... image=image, ... prompt=prompt, + ... image=image, + ... audio=audio, ... negative_prompt=negative_prompt, ... height=height, ... width=width, @@ -121,7 +124,7 @@ def retrieve_latents( class WanSpeechToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): r""" - Pipeline for image-and-sound-to-video generation using Wan. + Pipeline for prompt-image-sound-to-video generation using Wan-T2V. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
@@ -133,12 +136,14 @@ class WanSpeechToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): text_encoder ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - transformer ([`WanTransformer3DModel`]): + transformer ([`WanT2VTransformer3DModel`]): Conditional Transformer to denoise the input latents. scheduler ([`UniPCMultistepScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + audio_encoder ([`WanAudioEncoder`]): + Audio Encoder to process audio inputs. """ model_cpu_offload_seq = "text_encoder->audio_encoder->transformer->vae" @@ -153,7 +158,6 @@ def __init__( scheduler: FlowMatchEulerDiscreteScheduler, transformer: WanS2VTransformer3DModel = None, audio_encoder: WanAudioEncoder = None, - expand_timesteps: bool = False, ): super().__init__() @@ -165,7 +169,6 @@ def __init__( scheduler=scheduler, audio_encoder=audio_encoder, ) - self.register_to_config(expand_timesteps=expand_timesteps) self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 @@ -212,15 +215,15 @@ def _get_t5_prompt_embeds( return prompt_embeds - def encode_image( + def encode_audio( self, - image: PipelineImageInput, + audio: PipelineAudioInput, device: Optional[torch.device] = None, ): device = device or self._execution_device - image = self.image_processor(images=image, return_tensors="pt").to(device) - image_embeds = self.image_encoder(**image, output_hidden_states=True) - return image_embeds.hidden_states[-2] + # audio = self.audio_processor(audios=audio, return_tensors="pt").to(device) + audio_embeds = self.audio_encoder(**audio, output_hidden_states=True) + return audio_embeds.hidden_states[-2] # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt def encode_prompt( @@ -315,7 +318,8 @@ def check_inputs( negative_prompt_embeds=None, image_embeds=None, callback_on_step_end_tensor_inputs=None, - guidance_scale_2=None, + audio=None, + audio_embeds=None, ): if image is not None and image_embeds is not None: raise ValueError( @@ -358,6 +362,17 @@ def check_inputs( not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if audio is not None and audio_embeds is not None: + raise ValueError( + f"Cannot forward both `audio`: {audio} and `audio_embeds`: {audio_embeds}. Please make sure to" + " only forward one of the two." + ) + elif audio is None and audio_embeds is None: + raise ValueError( + "Provide either `audio` or `audio_embeds`. Cannot leave both `audio` and `audio_embeds` undefined." 
+ ) + elif audio is not None and not isinstance(audio, (torch.Tensor, list)): + raise ValueError(f"`audio` has to be of type `torch.Tensor` or `list` but is {type(audio)}") def prepare_latents( self, @@ -371,7 +386,6 @@ def prepare_latents( device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - last_image: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 latent_height = height // self.vae_scale_factor_spatial @@ -391,19 +405,10 @@ def prepare_latents( image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] - if self.config.expand_timesteps: - video_condition = image + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) - elif last_image is None: - video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 - ) - else: - last_image = last_image.unsqueeze(2) - video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], - dim=2, - ) video_condition = video_condition.to(device=device, dtype=self.vae.dtype) latents_mean = ( @@ -427,19 +432,9 @@ def prepare_latents( latent_condition = latent_condition.to(dtype) latent_condition = (latent_condition - latents_mean) * latents_std - if self.config.expand_timesteps: - first_frame_mask = torch.ones( - 1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device - ) - first_frame_mask[:, :, 0] = 0 - return latents, latent_condition, first_frame_mask - mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) - if last_image is None: - mask_lat_size[:, :, list(range(1, num_frames))] = 0 - else: - mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + mask_lat_size[:, :, list(range(1, num_frames))] = 0 first_frame_mask = mask_lat_size[:, :, 0:1] first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) @@ -480,19 +475,19 @@ def __call__( image: PipelineImageInput, prompt: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, + audio: PipelineAudioInput = None, height: int = 480, width: int = 832, num_frames: int = 81, num_inference_steps: int = 50, guidance_scale: float = 5.0, - guidance_scale_2: Optional[float] = None, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, - last_image: Optional[torch.Tensor] = None, + audio_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -515,6 +510,8 @@ def __call__( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + audio (`PipelineAudioInput`, *optional*): + The audio input to condition the generation on. Must be an audio, a list of audios or a `torch.Tensor`. 
height (`int`, defaults to `480`): The height of the generated video. width (`int`, defaults to `832`): @@ -548,6 +545,9 @@ def __call__( image_embeds (`torch.Tensor`, *optional*): Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, image embeddings are generated from the `image` input argument. + audio_embeds (`torch.Tensor`, *optional*): + Pre-generated audio embeddings. Can be used to easily tweak audio inputs (weighting). If not provided, + audio embeddings are generated from the `audio` input argument. output_type (`str`, *optional*, defaults to `"np"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -592,7 +592,8 @@ def __call__( negative_prompt_embeds, image_embeds, callback_on_step_end_tensor_inputs, - guidance_scale_2, + audio, + audio_embeds, ) if num_frames % self.vae_scale_factor_temporal != 1: @@ -603,7 +604,6 @@ def __call__( num_frames = max(num_frames, 1) self._guidance_scale = guidance_scale - self._guidance_scale_2 = guidance_scale_2 self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False @@ -636,6 +636,11 @@ def __call__( if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + if audio_embeds is None: + audio_embeds = self.encode_audio(audio, device) + audio_embeds = audio_embeds.repeat(batch_size, 1, 1) + audio_embeds = audio_embeds.to(transformer_dtype) + # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps @@ -655,13 +660,9 @@ def __call__( device, generator, latents, - last_image, ) - if self.config.expand_timesteps: - # wan 2.2 5b i2v use firt_frame_mask to mask timesteps - latents, condition, first_frame_mask = latents_outputs - else: - latents, condition = latents_outputs + + latents, condition = latents_outputs # 6. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -674,17 +675,8 @@ def __call__( self._current_timestep = t - if self.config.expand_timesteps: - latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents - latent_model_input = latent_model_input.to(transformer_dtype) - - # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size) - temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten() - # batch_size, seq_len - timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) - else: - latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) - timestep = t.expand(latents.shape[0]) + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) with self.transformer.cache_context("cond"): noise_pred = self.transformer( @@ -730,9 +722,6 @@ def __call__( self._current_timestep = None - if self.config.expand_timesteps: - latents = (1 - first_frame_mask) * condition + first_frame_mask * latents - if not output_type == "latent": latents = latents.to(self.vae.dtype) latents_mean = ( From 933b618efe22e0a68781a8f0395e4ab024d332b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 30 Aug 2025 10:42:01 +0300 Subject: [PATCH 008/131] Refactor Wan Speech-to-Video audio encoding The previous audio encoding logic was a placeholder. 
It is now replaced with a `Wav2Vec2ForCTC` model and processor, including the full implementation for processing audio inputs. This involves resampling and aligning audio features with video frames to ensure proper synchronization. Additionally, utility functions for loading audio from files or URLs are added, and the `audio_processor` module is refactored to correctly handle audio data types instead of image types. --- src/diffusers/audio_processor.py | 43 +++--- .../pipelines/wan/pipeline_wan_s2v.py | 129 ++++++++++++++++-- src/diffusers/utils/loading_utils.py | 47 +++++++ 3 files changed, 185 insertions(+), 34 deletions(-) diff --git a/src/diffusers/audio_processor.py b/src/diffusers/audio_processor.py index 31a1b09674be..8957aa97ded3 100644 --- a/src/diffusers/audio_processor.py +++ b/src/diffusers/audio_processor.py @@ -15,62 +15,57 @@ from typing import List, Union import numpy as np -import PIL.Image import torch PipelineAudioInput = Union[ - PIL.Image.Image, # librosa? np.ndarray, torch.Tensor, - List[PIL.Image.Image], # librosa? List[np.ndarray], List[torch.Tensor], ] -def is_valid_image(image) -> bool: +def is_valid_audio(audio) -> bool: r""" - Checks if the input is a valid image. + Checks if the input is a valid audio. - A valid image can be: - - A `PIL.Image.Image`. + A valid audio can be: - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). Args: - image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): - The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. + audio (`Union[np.ndarray, torch.Tensor]`): + The audio to validate. It can be a NumPy array or a torch tensor. Returns: `bool`: - `True` if the input is a valid image, `False` otherwise. + `True` if the input is a valid audio, `False` otherwise. """ - return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) + return isinstance(audio, (np.ndarray, torch.Tensor)) and audio.ndim in (2, 3) -def is_valid_image_imagelist(images): +def is_valid_audio_audiolist(audios): r""" - Checks if the input is a valid image or list of images. + Checks if the input is a valid audio or list of audios. The input can be one of the following formats: - - A 4D tensor or numpy array (batch of images). - - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or - `torch.Tensor`. - - A list of valid images. + - A 4D tensor or numpy array (batch of audios). + - A valid single audio: 2D `np.ndarray` or `torch.Tensor` (grayscale audio), 3D `np.ndarray` or `torch.Tensor`. + - A list of valid audios. Args: - images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`): - The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid - images. + audios (`Union[np.ndarray, torch.Tensor, List]`): + The audio(s) to check. Can be a batch of audios (4D tensor/array), a single audio, or a list of valid + audios. Returns: `bool`: `True` if the input is valid, `False` otherwise. 
""" - if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: + if isinstance(audios, (np.ndarray, torch.Tensor)) and audios.ndim == 4: return True - elif is_valid_image(images): + elif is_valid_audio(audios): return True - elif isinstance(images, list): - return all(is_valid_image(image) for image in images) + elif isinstance(audios, list): + return all(is_valid_audio(audio) for audio in audios) return False diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index e17d103350e5..ab0a4f91d431 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -13,12 +13,15 @@ # limitations under the License. import html +import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np import PIL import regex as re import torch -from transformers import AutoTokenizer, UMT5EncoderModel +import torch.nn.functional as F +from transformers import AutoTokenizer, UMT5EncoderModel, Wav2Vec2ForCTC, Wav2Vec2Processor from ...audio_processor import PipelineAudioInput from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -30,7 +33,6 @@ from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline -from .audio_encoder import WanAudioEncoder from .pipeline_output import WanPipelineOutput @@ -108,6 +110,42 @@ def prompt_clean(text): return text +def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None): + required_duration = num_sample / target_fps + required_origin_frames = int(np.ceil(required_duration * original_fps)) + if required_duration > total_frames / original_fps: + raise ValueError("required_duration must be less than video length") + + if fixed_start is not None and fixed_start >= 0: + start_frame = fixed_start + else: + max_start = total_frames - required_origin_frames + if max_start < 0: + raise ValueError("video length is too short") + start_frame = np.random.randint(0, max_start + 1) + start_time = start_frame / original_fps + + end_time = start_time + required_duration + time_points = np.linspace(start_time, end_time, num_sample, endpoint=False) + + frame_indices = np.round(np.array(time_points) * original_fps).astype(int) + frame_indices = np.clip(frame_indices, 0, total_frames - 1) + return frame_indices + + +def linear_interpolation(features, input_fps, output_fps, output_len=None): + """ + features: shape=[1, T, 512] input_fps: fps for audio, f_a output_fps: fps for video, f_m output_len: video length + """ + features = features.transpose(1, 2) # [1, 512, T] + seq_len = features.shape[2] / float(input_fps) # T/f_a + output_len = int(seq_len * output_fps) # f_m*T/f_a + output_features = F.interpolate( + features, size=output_len, align_corners=True, mode="linear" + ) # [1, 512, output_len] + return output_features.transpose(1, 2) # [1, output_len, 512] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -124,7 +162,7 @@ def retrieve_latents( class WanSpeechToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): r""" - Pipeline for prompt-image-sound-to-video generation using Wan-T2V. + Pipeline for prompt-image-audio-to-video generation using Wan2.2-S2V. This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). @@ -142,13 +180,14 @@ class WanSpeechToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. - audio_encoder ([`WanAudioEncoder`]): + audio_encoder ([`Wav2Vec2ForCTC`]): Audio Encoder to process audio inputs. + audio_processor ([`Wav2Vec2Processor`]): + Audio Processor to preprocess audio inputs. """ model_cpu_offload_seq = "text_encoder->audio_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - _optional_components = ["transformer", "image_encoder", "image_processor"] def __init__( self, @@ -156,8 +195,9 @@ def __init__( text_encoder: UMT5EncoderModel, vae: AutoencoderKLWan, scheduler: FlowMatchEulerDiscreteScheduler, - transformer: WanS2VTransformer3DModel = None, - audio_encoder: WanAudioEncoder = None, + transformer: WanS2VTransformer3DModel, + audio_encoder: Wav2Vec2ForCTC, + audio_processor: Wav2Vec2Processor, ): super().__init__() @@ -168,11 +208,13 @@ def __init__( transformer=transformer, scheduler=scheduler, audio_encoder=audio_encoder, + audio_processor=audio_processor, ) self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.audio_processor = audio_processor def _get_t5_prompt_embeds( self, @@ -218,12 +260,79 @@ def _get_t5_prompt_embeds( def encode_audio( self, audio: PipelineAudioInput, + sampling_rate: int, + infer_frames: int, + fps: int = 16, device: Optional[torch.device] = None, ): device = device or self._execution_device - # audio = self.audio_processor(audios=audio, return_tensors="pt").to(device) - audio_embeds = self.audio_encoder(**audio, output_hidden_states=True) - return audio_embeds.hidden_states[-2] + video_rate = 30 + audio_sample_m = 0 + + input_values = self.audio_processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_values + + # retrieve logits & take argmax + res = self.audio_encoder(input_values.to(self.audio_encoder.device), output_hidden_states=True) + feat = torch.cat(res.hidden_states) + + feat = linear_interpolation(feat, input_fps=50, output_fps=30) + + audio_embed = feat.to(torch.float32) # Encoding for the motion + + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + scale = video_rate / fps + + num_repeat = int(audio_frame_num / (infer_frames * scale)) + 1 + + bucket_num = num_repeat * infer_frames + padd_audio_num = math.ceil(num_repeat * infer_frames / fps * video_rate) - audio_frame_num + batch_idx = get_sample_indices( + original_fps=video_rate, + total_frames=audio_frame_num + padd_audio_num, + target_fps=fps, + num_sample=bucket_num, + fixed_start=0, + ) + batch_audio_eb = [] + audio_sample_stride = int(self.video_rate / fps) + for bi in batch_idx: + if bi < audio_frame_num: + chosen_idx = list( + range( + bi - audio_sample_m * audio_sample_stride, + bi + (audio_sample_m + 1) * audio_sample_stride, + audio_sample_stride, + ) + ) + chosen_idx = [0 if c < 
0 else c for c in chosen_idx] + chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = ( + torch.zeros([audio_dim * (2 * audio_sample_m + 1)], device=audio_embed.device) + if not return_all_layers + else torch.zeros([num_layers, audio_dim * (2 * audio_sample_m + 1)], device=audio_embed.device) + ) + batch_audio_eb.append(frame_audio_embed) + audio_embed_bucket = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + audio_embed_bucket = audio_embed_bucket.to(self.device, self.param_dtype) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0) + if len(audio_embed_bucket.shape) == 3: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) + elif len(audio_embed_bucket.shape) == 4: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) + return audio_embed_bucket, num_repeat # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt def encode_prompt( diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index dd23ae73c861..31aec907d712 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -3,6 +3,8 @@ from typing import Any, Callable, List, Optional, Tuple, Union from urllib.parse import unquote, urlparse +import librosa +import numpy import PIL.Image import PIL.ImageOps import requests @@ -138,6 +140,51 @@ def load_video( return pil_images +def load_audio( + audio: Union[str, numpy.ndarray], convert_method: Optional[Callable[[numpy.ndarray], numpy.ndarray]] = None +) -> numpy.ndarray: + """ + Loads `audio` to a numpy array. + + Args: + audio (`str` or `numpy.ndarray`): + The audio to convert to the numpy array format. + convert_method (Callable[[numpy.ndarray], numpy.ndarray], *optional*): + A conversion method to apply to the audio after loading it. When set to `None` the audio will be converted + to a specific format. + + Returns: + `numpy.ndarray`: + A Librosa audio object. + `int`: + The sample rate of the audio. + """ + if isinstance(audio, str): + if audio.startswith("http://") or audio.startswith("https://"): + audio = PIL.Image.open(requests.get(audio, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw) + elif os.path.isfile(audio): + audio, sample_rate = librosa.load(audio, sr=16000) + else: + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {audio} is not a valid path." + ) + elif isinstance(audio, numpy.ndarray): + audio = audio + else: + raise ValueError( + "Incorrect format used for the audio. Should be a URL linking to an audio, a local path, or a PIL audio." + ) + + # audio = PIL.ImageOps.exif_transpose(audio) + + if convert_method is not None: + audio = convert_method(audio) + else: + audio = audio.convert("RGB") + + return audio, sample_rate + + # Taken from `transformers`. def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]: if "." 
in tensor_name: From 6d55c939a1d86aade30e0fbc0f9ca74c65fce319 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 30 Aug 2025 12:11:38 +0300 Subject: [PATCH 009/131] up --- docs/source/en/api/pipelines/wan.md | 94 +++++++++++++++---- scripts/convert_wan_to_diffusers.py | 14 ++- .../pipelines/wan/pipeline_wan_s2v.py | 14 ++- 3 files changed, 96 insertions(+), 26 deletions(-) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 3289a840e2b1..b22f871925a5 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -40,6 +40,7 @@ The following Wan models are supported in Diffusers: - [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) - [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) - [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers) +- [Wan 2.2 S2V 14B](https://huggingface.co/Wan-AI/Wan2.2-S2V-14B-Diffusers) > [!TIP] > Click on the Wan models in the right sidebar for more examples of video generation. @@ -95,15 +96,15 @@ pipeline = WanPipeline.from_pretrained( pipeline.to("cuda") prompt = """ -The camera rushes from far to near in a low-angle shot, -revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in -for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. -Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic +The camera rushes from far to near in a low-angle shot, +revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in +for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. +Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic shadows and warm highlights. Medium composition, front view, low angle, with depth of field. """ negative_prompt = """ -Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, -low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards """ @@ -150,15 +151,15 @@ pipeline.transformer = torch.compile( ) prompt = """ -The camera rushes from far to near in a low-angle shot, -revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in -for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. -Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic +The camera rushes from far to near in a low-angle shot, +revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in +for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. +Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic shadows and warm highlights. 
Medium composition, front view, low angle, with depth of field. """ negative_prompt = """ -Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, -low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards """ @@ -236,6 +237,61 @@ export_to_video(output, "output.mp4", fps=16) + +### Wan-S2V: Audio-Driven Cinematic Video Generation + +[Wan-S2V](https://huggingface.co/papers/2508.18621) by the Wan Team. + +*Current state-of-the-art (SOTA) methods for audio-driven character animation demonstrate promising performance for scenarios primarily involving speech and singing. However, they often fall short in more complex film and television productions, which demand sophisticated elements such as nuanced character interactions, realistic body movements, and dynamic camera work. To address this long-standing challenge of achieving film-level character animation, we propose an audio-driven model, which we refere to as Wan-S2V, built upon Wan. Our model achieves significantly enhanced expressiveness and fidelity in cinematic contexts compared to existing approaches. We conducted extensive experiments, benchmarking our method against cutting-edge models such as Hunyuan-Avatar and Omnihuman. The experimental results consistently demonstrate that our approach significantly outperforms these existing solutions. Additionally, we explore the versatility of our method through its applications in long-form video generation and precise video lip-sync editing.* + +The example below demonstrates how to use the speech-to-video pipeline to generate a video using a text description, a starting frame, and an audio. 
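+
+The pipeline generates the video in chunks of `num_frames_per_chunk` frames (81 by default) and keeps producing chunks until the whole audio track is covered, conditioning each new chunk on the last motion frames of the previous one. As a rough, non-authoritative sketch of that bookkeeping (it mirrors the chunk counting inside `encode_audio`; the `load_audio` helper added in this patch decodes local files with librosa at 16 kHz and returns the waveform together with its sampling rate; the file path below is only a placeholder):
+
+```python
+from diffusers.utils import load_audio
+
+audio, sampling_rate = load_audio("path/to/speech.wav")  # placeholder local file
+
+duration_s = len(audio) / sampling_rate  # mono waveform length in seconds
+num_frames_per_chunk = 81  # pipeline default chunk length
+sampling_fps = 16  # pipeline default output frame rate
+
+# Audio features are interpolated to 30 Hz internally and then covered with
+# chunks of num_frames_per_chunk frames at sampling_fps, which works out to roughly:
+num_chunks = int(duration_s * sampling_fps / num_frames_per_chunk) + 1
+total_frames = num_chunks * num_frames_per_chunk
+print(f"{num_chunks} chunks -> about {total_frames} frames at {sampling_fps} fps")
+```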
+ + + + +```python +import numpy as np +import torch +from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline +from diffusers.utils import export_to_video, load_image, load_audio +from transformers import Wav2Vec2ForCTC + + +model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" +audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32) +vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +pipe = WanSpeechToVideoPipeline.from_pretrained( + model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png") +audio = load_audio("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png") +pose_video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_pose_video.mp4") + +def aspect_ratio_resize(image, pipe, max_area=720 * 1280): + aspect_ratio = image.height / image.width + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + return image, height, width + +first_frame, height, width = aspect_ratio_resize(first_frame, pipe) + +prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." + +output = pipe( + image=first_frame, audio=audio, prompt=prompt, height=height, width=width, guidance_scale=5.0, + # pose_video=pose_video +).frames[0] +export_to_video(output, "output.mp4", fps=16) +``` + + + + + ### Any-to-Video Controllable Generation Wan VACE supports various generation techniques which achieve controllable video generation. Some of the capabilities include: @@ -281,10 +337,10 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip # use "steamboat willie style" to trigger the LoRA prompt = """ - steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot, - revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in - for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. - Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic + steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot, + revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in + for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. + Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic shadows and warm highlights. Medium composition, front view, low angle, with depth of field. 
""" @@ -353,6 +409,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip - all - __call__ +## WanSpeechToVideoPipeline + +[[autodoc]] WanSpeechToVideoPipeline + - all + - __call__ + ## WanVideoToVideoPipeline [[autodoc]] WanVideoToVideoPipeline diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 1ad7d731e0f6..c7b78839afa7 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -6,7 +6,7 @@ from accelerate import init_empty_weights from huggingface_hub import hf_hub_download, snapshot_download from safetensors.torch import load_file -from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel +from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel, Wav2Vec2Processor, Wav2Vec2ForCTC from diffusers import ( AutoencoderKLWan, @@ -19,7 +19,6 @@ WanVACEPipeline, WanVACETransformer3DModel, ) -from diffusers.pipelines.wan.audio_encoder import WanAudioEncoder TRANSFORMER_KEYS_RENAME_DICT = { @@ -945,7 +944,7 @@ def get_args(): tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") if "FLF2V" in args.model_type: flow_shift = 16.0 - elif "TI2V" in args.model_type: + elif "TI2V" in args.model_type or "S2V" in args.model_type: flow_shift = 5.0 else: flow_shift = 3.0 @@ -1010,7 +1009,8 @@ def get_args(): scheduler=scheduler, ) elif "S2V" in args.model_type: - audio_encoder = WanAudioEncoder.from_pretrained("facebook/wav2vec2-base-960h") + audio_encoder = Wav2Vec2ForCTC.from_pretrained("Wan-AI/Wan2.2-S2V-14B", subfolder="wav2vec2-large-xlsr-53-english") + audio_processor = Wav2Vec2Processor.from_pretrained("Wan-AI/Wan2.2-S2V-14B", subfolder="wav2vec2-large-xlsr-53-english") pipe = WanSpeechToVideoPipeline( transformer=transformer, text_encoder=text_encoder, @@ -1018,6 +1018,7 @@ def get_args(): vae=vae, scheduler=scheduler, audio_encoder=audio_encoder, + audio_processor=audio_processor, ) else: pipe = WanPipeline( @@ -1028,4 +1029,7 @@ def get_args(): scheduler=scheduler, ) - pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + pipe.save_pretrained(args.output_path, + push_to_hub=True, + safe_serialization=True, + max_shard_size="5GB") diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index ab0a4f91d431..f42dbc471998 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -17,6 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np +from PIL import Image import PIL import regex as re import torch @@ -582,9 +583,10 @@ def attention_kwargs(self): def __call__( self, image: PipelineImageInput, - prompt: Union[str, List[str]] = None, + audio: PipelineAudioInput, + prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] = None, - audio: PipelineAudioInput = None, + pose_video: Optional[List[Image.Image]] = None, height: int = 480, width: int = 832, num_frames: int = 81, @@ -612,15 +614,17 @@ def __call__( Args: image (`PipelineImageInput`): The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. - prompt (`str` or `List[str]`, *optional*): + audio (`PipelineAudioInput`): + The audio input to condition the generation on. Must be an audio, a list of audios or a `torch.Tensor`. 
+ prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - audio (`PipelineAudioInput`, *optional*): - The audio input to condition the generation on. Must be an audio, a list of audios or a `torch.Tensor`. + pose_video (`List[Image.Image]`, *optional*): + A list of PIL images representing the pose video to condition the generation on. height (`int`, defaults to `480`): The height of the generated video. width (`int`, defaults to `832`): From e6f6a22b98f3314362f2e978b13d4a7cc05f2a28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 30 Aug 2025 18:19:39 +0300 Subject: [PATCH 010/131] up --- docs/source/en/api/pipelines/wan.md | 5 +- scripts/convert_wan_to_diffusers.py | 22 ++- .../pipelines/wan/pipeline_wan_s2v.py | 146 ++++++++++++++++-- 3 files changed, 152 insertions(+), 21 deletions(-) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index b22f871925a5..3754ecdbcb7c 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -266,7 +266,7 @@ pipe = WanSpeechToVideoPipeline.from_pretrained( pipe.to("cuda") first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png") -audio = load_audio("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png") +audio, sampling_rate = load_audio("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png") pose_video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_pose_video.mp4") def aspect_ratio_resize(image, pipe, max_area=720 * 1280): @@ -282,7 +282,8 @@ first_frame, height, width = aspect_ratio_resize(first_frame, pipe) prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." 
output = pipe( - image=first_frame, audio=audio, prompt=prompt, height=height, width=width, guidance_scale=5.0, + image=first_frame, audio=audio, sampling_rate=sampling_rate, + prompt=prompt, height=height, width=width, guidance_scale=5.0, # pose_video=pose_video ).frames[0] export_to_video(output, "output.mp4", fps=16) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index c7b78839afa7..613189e30ea7 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -6,7 +6,14 @@ from accelerate import init_empty_weights from huggingface_hub import hf_hub_download, snapshot_download from safetensors.torch import load_file -from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel, Wav2Vec2Processor, Wav2Vec2ForCTC +from transformers import ( + AutoProcessor, + AutoTokenizer, + CLIPVisionModelWithProjection, + UMT5EncoderModel, + Wav2Vec2ForCTC, + Wav2Vec2Processor, +) from diffusers import ( AutoencoderKLWan, @@ -1009,8 +1016,12 @@ def get_args(): scheduler=scheduler, ) elif "S2V" in args.model_type: - audio_encoder = Wav2Vec2ForCTC.from_pretrained("Wan-AI/Wan2.2-S2V-14B", subfolder="wav2vec2-large-xlsr-53-english") - audio_processor = Wav2Vec2Processor.from_pretrained("Wan-AI/Wan2.2-S2V-14B", subfolder="wav2vec2-large-xlsr-53-english") + audio_encoder = Wav2Vec2ForCTC.from_pretrained( + "Wan-AI/Wan2.2-S2V-14B", subfolder="wav2vec2-large-xlsr-53-english" + ) + audio_processor = Wav2Vec2Processor.from_pretrained( + "Wan-AI/Wan2.2-S2V-14B", subfolder="wav2vec2-large-xlsr-53-english" + ) pipe = WanSpeechToVideoPipeline( transformer=transformer, text_encoder=text_encoder, @@ -1029,7 +1040,4 @@ def get_args(): scheduler=scheduler, ) - pipe.save_pretrained(args.output_path, - push_to_hub=True, - safe_serialization=True, - max_shard_size="5GB") + pipe.save_pretrained(args.output_path, push_to_hub=True, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index f42dbc471998..5661599aa71a 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -17,11 +17,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -from PIL import Image import PIL import regex as re import torch import torch.nn.functional as F +from PIL import Image from transformers import AutoTokenizer, UMT5EncoderModel, Wav2Vec2ForCTC, Wav2Vec2Processor from ...audio_processor import PipelineAudioInput @@ -216,6 +216,8 @@ def __init__( self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.audio_processor = audio_processor + self.motion_frames = 73 + self.drop_first_motion = True def _get_t5_prompt_embeds( self, @@ -262,7 +264,7 @@ def encode_audio( self, audio: PipelineAudioInput, sampling_rate: int, - infer_frames: int, + num_frames: int, fps: int = 16, device: Optional[torch.device] = None, ): @@ -289,10 +291,10 @@ def encode_audio( scale = video_rate / fps - num_repeat = int(audio_frame_num / (infer_frames * scale)) + 1 + num_repeat = int(audio_frame_num / (num_frames * scale)) + 1 - bucket_num = num_repeat * infer_frames - padd_audio_num = math.ceil(num_repeat * infer_frames / fps * video_rate) - audio_frame_num + bucket_num = num_repeat * num_frames + padd_audio_num = 
math.ceil(num_repeat * num_frames / fps * video_rate) - audio_frame_num batch_idx = get_sample_indices( original_fps=video_rate, total_frames=audio_frame_num + padd_audio_num, @@ -496,6 +498,10 @@ def prepare_latents( device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, + init_first_frame: Optional[bool] = False, + pose_video: Optional[List[Image.Image]] = None, + num_repeat: Optional[int] = 1, + sampling_fps: Optional[int] = 16 ) -> Tuple[torch.Tensor, torch.Tensor]: num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 latent_height = height // self.vae_scale_factor_spatial @@ -508,6 +514,8 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) + motion_latents = torch.zeros([1, 3, self.motion_frames, height, width], dtype=dtype, device=device) + if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -542,6 +550,21 @@ def prepare_latents( latent_condition = latent_condition.to(dtype) latent_condition = (latent_condition - latents_mean) * latents_std + # Encode motion latents + videos_last_frames = motion_latents.detach() + if init_first_frame: + self.drop_first_motion = False + motion_latents[:, :, -6:] = video_condition + motion_latents = torch.stack(self.vae.encode(motion_latents)) + + # Get pose condition input if needed + COND = self.load_pose_condition( + pose_video, + num_repeat, + num_frames, + (height, width), + sampling_fps) + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) mask_lat_size[:, :, list(range(1, num_frames))] = 0 @@ -552,7 +575,89 @@ def prepare_latents( mask_lat_size = mask_lat_size.transpose(1, 2) mask_lat_size = mask_lat_size.to(latent_condition.device) - return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + return (latents, torch.concat([mask_lat_size, latent_condition], dim=1)), motion_latents + + def load_pose_condition(self, pose_video, num_repeat, infer_frames, size, sampling_fps): + HEIGHT, WIDTH = size + if not pose_video is None: + pose_seq = self.read_last_n_frames( + pose_video, + n_frames=infer_frames * num_repeat, + target_fps=sampling_fps, + reverse=True) + + resize_opreat = transforms.Resize(min(HEIGHT, WIDTH)) + crop_opreat = transforms.CenterCrop((HEIGHT, WIDTH)) + tensor_trans = transforms.ToTensor() + + cond_tensor = torch.from_numpy(pose_seq) + cond_tensor = cond_tensor.permute(0, 3, 1, 2) / 255.0 * 2 - 1.0 + cond_tensor = crop_opreat(resize_opreat(cond_tensor)).permute( + 1, 0, 2, 3).unsqueeze(0) + + padding_frame_num = num_repeat * infer_frames - cond_tensor.shape[2] + cond_tensor = torch.cat([ + cond_tensor, + - torch.ones([1, 3, padding_frame_num, HEIGHT, WIDTH]) + ], + dim=2) + + cond_tensors = torch.chunk(cond_tensor, num_repeat, dim=2) + else: + cond_tensors = [-torch.ones([1, 3, infer_frames, HEIGHT, WIDTH])] + + COND = [] + for r in range(len(cond_tensors)): + cond = cond_tensors[r] + cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], + dim=2) + cond_lat = torch.stack( + self.vae.encode( + cond.to(dtype=self.param_dtype, + device=self.device)))[:, :, + 1:].cpu() # for mem save + COND.append(cond_lat) + + return COND + + def read_last_n_frames(self, + video_path, + n_frames, + target_fps=16, + reverse=False): + """ + Read the last `n_frames` from a video at the specified frame rate. + + Parameters: + video_path (str): Path to the video file. 
+ n_frames (int): Number of frames to read. + target_fps (int, optional): Target sampling frame rate. Defaults to 16. + reverse (bool, optional): Whether to read frames in reverse order. + If True, reads the first `n_frames` instead of the last ones. + + Returns: + np.ndarray: A NumPy array of shape [n_frames, H, W, 3], representing the sampled video frames. + """ + vr = VideoReader(video_path) + original_fps = vr.get_avg_fps() + total_frames = len(vr) + + interval = max(1, round(original_fps / target_fps)) + + required_span = (n_frames - 1) * interval + + start_frame = max(0, total_frames - required_span - + 1) if not reverse else 0 + + sampled_indices = [] + for i in range(n_frames): + indice = start_frame + i * interval + if indice >= total_frames: + break + else: + sampled_indices.append(indice) + + return vr.get_batch(sampled_indices).asnumpy() @property def guidance_scale(self): @@ -584,12 +689,13 @@ def __call__( self, image: PipelineImageInput, audio: PipelineAudioInput, + sampling_rate: int, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] = None, pose_video: Optional[List[Image.Image]] = None, height: int = 480, width: int = 832, - num_frames: int = 81, + num_frames: int = 80, num_inference_steps: int = 50, guidance_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, @@ -607,6 +713,8 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + init_first_frame: bool = False, + sampling_fps: int = 16, ): r""" The call function to the pipeline for generation. @@ -616,6 +724,8 @@ def __call__( The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. audio (`PipelineAudioInput`): The audio input to condition the generation on. Must be an audio, a list of audios or a `torch.Tensor`. + sampling_rate (`int`): + The sampling rate of the audio input. prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -629,8 +739,8 @@ def __call__( The height of the generated video. width (`int`, defaults to `832`): The width of the generated video. - num_frames (`int`, defaults to `81`): - The number of frames in the generated video. + num_frames (`int`, defaults to `80`): + The number of frames in the generated video. The number should be a multiple of 4. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -681,6 +791,10 @@ def __call__( max_sequence_length (`int`, defaults to `512`): The maximum sequence length of the text encoder. If the prompt is longer than this, it will be truncated. If the prompt is shorter, it will be padded to this length. + init_first_frame (`bool`, *optional*, defaults to False): + Whether to use the reference image as the first frame (i.e., standard image-to-video generation). + sampling_fps (`int`, *optional*, defaults to 16): + The frame rate (in frames per second) at which the generated video will be sampled. 
Examples: @@ -743,17 +857,21 @@ def __call__( device=device, ) - # Encode image embedding transformer_dtype = self.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) if audio_embeds is None: - audio_embeds = self.encode_audio(audio, device) + audio_embeds, num_repeat = self.encode_audio(audio, sampling_rate, num_frames, sample_fps, device) + # TODO: num_repeat_input vs. num_repeat? + if num_repeat_input is None or num_repeat_input > num_repeat: + num_repeat_input = num_repeat audio_embeds = audio_embeds.repeat(batch_size, 1, 1) audio_embeds = audio_embeds.to(transformer_dtype) + latent_motion_frames = (self.motion_frames + 3) // 4 + # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps @@ -762,7 +880,7 @@ def __call__( num_channels_latents = self.vae.config.z_dim image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) - latents_outputs = self.prepare_latents( + latents_outputs, motion_latents = self.prepare_latents( image, batch_size * num_videos_per_prompt, num_channels_latents, @@ -773,6 +891,10 @@ def __call__( device, generator, latents, + init_first_frame, + pose_video, + num_repeat_input, + sampling_fps ) latents, condition = latents_outputs From 313fea5389990c88dab41fd54da2590a5932e783 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 30 Aug 2025 20:26:11 +0300 Subject: [PATCH 011/131] up --- src/diffusers/pipelines/wan/audio_encoder.py | 169 ------------------ .../pipelines/wan/pipeline_wan_s2v.py | 8 +- 2 files changed, 4 insertions(+), 173 deletions(-) delete mode 100644 src/diffusers/pipelines/wan/audio_encoder.py diff --git a/src/diffusers/pipelines/wan/audio_encoder.py b/src/diffusers/pipelines/wan/audio_encoder.py deleted file mode 100644 index 3150f450d9c7..000000000000 --- a/src/diffusers/pipelines/wan/audio_encoder.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import math - -import librosa -import numpy as np -import torch -import torch.nn.functional as F -from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor - - -def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None): - required_duration = num_sample / target_fps - required_origin_frames = int(np.ceil(required_duration * original_fps)) - if required_duration > total_frames / original_fps: - raise ValueError("required_duration must be less than video length") - - if fixed_start is not None and fixed_start >= 0: - start_frame = fixed_start - else: - max_start = total_frames - required_origin_frames - if max_start < 0: - raise ValueError("video length is too short") - start_frame = np.random.randint(0, max_start + 1) - start_time = start_frame / original_fps - - end_time = start_time + required_duration - time_points = np.linspace(start_time, end_time, num_sample, endpoint=False) - - frame_indices = np.round(np.array(time_points) * original_fps).astype(int) - frame_indices = np.clip(frame_indices, 0, total_frames - 1) - return frame_indices - - -def linear_interpolation(features, input_fps, output_fps, output_len=None): - """ - features: shape=[1, T, 512] input_fps: fps for audio, f_a output_fps: fps for video, f_m output_len: video length - """ - features = features.transpose(1, 2) # [1, 512, T] - seq_len = features.shape[2] / float(input_fps) # T/f_a - if output_len is None: - output_len = int(seq_len * output_fps) # f_m*T/f_a - output_features = F.interpolate( - features, size=output_len, align_corners=True, mode="linear" - ) # [1, 512, output_len] - return output_features.transpose(1, 2) # [1, output_len, 512] - - -class WanAudioEncoder: - def __init__(self, device="cpu", model_id="facebook/wav2vec2-base-960h"): - # load pretrained model - self.processor = Wav2Vec2Processor.from_pretrained(model_id) - self.model = Wav2Vec2ForCTC.from_pretrained(model_id) - - self.model = self.model.to(device) - - self.video_rate = 30 - - def extract_audio_feat(self, audio_path, return_all_layers=False, dtype=torch.float32): - audio_input, sample_rate = librosa.load(audio_path, sr=16000) - - input_values = self.processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values - - # INFERENCE - - # retrieve logits & take argmax - res = self.model(input_values.to(self.model.device), output_hidden_states=True) - if return_all_layers: - feat = torch.cat(res.hidden_states) - else: - feat = res.hidden_states[-1] - feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate) - - z = feat.to(dtype) # Encoding for the motion - return z - - def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2): - num_layers, audio_frame_num, audio_dim = audio_embed.shape - - if num_layers > 1: - return_all_layers = True - else: - return_all_layers = False - - min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1 - - bucket_num = min_batch_num * batch_frames - batch_idx = [stride * i for i in range(bucket_num)] - batch_audio_eb = [] - for bi in batch_idx: - if bi < audio_frame_num: - audio_sample_stride = 2 - chosen_idx = list( - range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride) - ) - chosen_idx = [0 if c < 0 else c for c in chosen_idx] - chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] - - if return_all_layers: - frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) - else: - frame_audio_embed = 
audio_embed[0][chosen_idx].flatten() - else: - frame_audio_embed = ( - torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) - if not return_all_layers - else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) - ) - batch_audio_eb.append(frame_audio_embed) - batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) - - return batch_audio_eb, min_batch_num - - def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0): - num_layers, audio_frame_num, audio_dim = audio_embed.shape - - if num_layers > 1: - return_all_layers = True - else: - return_all_layers = False - - scale = self.video_rate / fps - - min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1 - - bucket_num = min_batch_num * batch_frames - padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num - batch_idx = get_sample_indices( - original_fps=self.video_rate, - total_frames=audio_frame_num + padd_audio_num, - target_fps=fps, - num_sample=bucket_num, - fixed_start=0, - ) - batch_audio_eb = [] - audio_sample_stride = int(self.video_rate / fps) - for bi in batch_idx: - if bi < audio_frame_num: - chosen_idx = list( - range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride) - ) - chosen_idx = [0 if c < 0 else c for c in chosen_idx] - chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] - - if return_all_layers: - frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) - else: - frame_audio_embed = audio_embed[0][chosen_idx].flatten() - else: - frame_audio_embed = ( - torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) - if not return_all_layers - else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) - ) - batch_audio_eb.append(frame_audio_embed) - batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) - - return batch_audio_eb, min_batch_num diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 5661599aa71a..0c71a6d56db9 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -343,7 +343,7 @@ def encode_prompt( prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, do_classifier_free_guidance: bool = True, - num_videos_per_prompt: int = 1, + num_videos_per_prompt: Optional[int] = 1, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: int = 226, @@ -442,7 +442,7 @@ def check_inputs( raise ValueError( "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." 
) - if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, Image.Image): raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -502,7 +502,7 @@ def prepare_latents( pose_video: Optional[List[Image.Image]] = None, num_repeat: Optional[int] = 1, sampling_fps: Optional[int] = 16 - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor] num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 latent_height = height // self.vae_scale_factor_spatial latent_width = width // self.vae_scale_factor_spatial @@ -863,7 +863,7 @@ def __call__( negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) if audio_embeds is None: - audio_embeds, num_repeat = self.encode_audio(audio, sampling_rate, num_frames, sample_fps, device) + audio_embeds, num_repeat = self.encode_audio(audio, sampling_rate, num_frames, sampling_fps, device) # TODO: num_repeat_input vs. num_repeat? if num_repeat_input is None or num_repeat_input > num_repeat: num_repeat_input = num_repeat From 4ac93391114fd8864c3137bb9947ae2e44995eed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 1 Sep 2025 10:21:47 +0300 Subject: [PATCH 012/131] Improve Wan S2V pipeline --- src/diffusers/audio_processor.py | 2 +- .../pipelines/wan/pipeline_wan_s2v.py | 279 ++++++++++-------- 2 files changed, 155 insertions(+), 126 deletions(-) diff --git a/src/diffusers/audio_processor.py b/src/diffusers/audio_processor.py index 8957aa97ded3..491aacf530aa 100644 --- a/src/diffusers/audio_processor.py +++ b/src/diffusers/audio_processor.py @@ -50,7 +50,7 @@ def is_valid_audio_audiolist(audios): The input can be one of the following formats: - A 4D tensor or numpy array (batch of audios). - - A valid single audio: 2D `np.ndarray` or `torch.Tensor` (grayscale audio), 3D `np.ndarray` or `torch.Tensor`. + - A valid single audio: `np.ndarray` or `torch.Tensor`. - A list of valid audios. 
Args: diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 0c71a6d56db9..306383b4cd82 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -14,14 +14,16 @@ import html import math +from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import PIL import regex as re import torch import torch.nn.functional as F +from decord import VideoReader from PIL import Image +from torchvision import transforms from transformers import AutoTokenizer, UMT5EncoderModel, Wav2Vec2ForCTC, Wav2Vec2Processor from ...audio_processor import PipelineAudioInput @@ -29,7 +31,7 @@ from ...image_processor import PipelineImageInput from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, WanS2VTransformer3DModel -from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...schedulers import UniPCMultistepScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor @@ -195,7 +197,7 @@ def __init__( tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, vae: AutoencoderKLWan, - scheduler: FlowMatchEulerDiscreteScheduler, + scheduler: UniPCMultistepScheduler, transformer: WanS2VTransformer3DModel, audio_encoder: Wav2Vec2ForCTC, audio_processor: Wav2Vec2Processor, @@ -343,7 +345,7 @@ def encode_prompt( prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, do_classifier_free_guidance: bool = True, - num_videos_per_prompt: Optional[int] = 1, + num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: int = 226, @@ -493,17 +495,17 @@ def prepare_latents( num_channels_latents: int = 16, height: int = 480, width: int = 832, - num_frames: int = 81, + num_frames_per_chunk: int = 81, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, init_first_frame: Optional[bool] = False, pose_video: Optional[List[Image.Image]] = None, - num_repeat: Optional[int] = 1, - sampling_fps: Optional[int] = 16 - ) -> Tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor] - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + num_chunks: Optional[int] = 1, + sampling_fps: Optional[int] = 16, + ) -> Tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames_per_chunk - 1) // self.vae_scale_factor_temporal + 1 latent_height = height // self.vae_scale_factor_spatial latent_width = width // self.vae_scale_factor_spatial @@ -514,7 +516,7 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - motion_latents = torch.zeros([1, 3, self.motion_frames, height, width], dtype=dtype, device=device) + motion_pixels = torch.zeros([1, 3, self.motion_frames, height, width], dtype=dtype, device=device) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) @@ -524,7 +526,7 @@ def prepare_latents( image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames_per_chunk - 1, height, width)], dim=2 ) video_condition = video_condition.to(device=device, dtype=self.vae.dtype) @@ -551,80 +553,50 @@ def prepare_latents( latent_condition = (latent_condition - latents_mean) * latents_std # Encode motion latents - videos_last_frames = motion_latents.detach() if init_first_frame: self.drop_first_motion = False - motion_latents[:, :, -6:] = video_condition - motion_latents = torch.stack(self.vae.encode(motion_latents)) + motion_pixels[:, :, -6:] = video_condition + motion_latents = torch.stack(self.vae.encode(motion_pixels)) + videos_last_latents = motion_latents.detach() # Get pose condition input if needed - COND = self.load_pose_condition( - pose_video, - num_repeat, - num_frames, - (height, width), - sampling_fps) - - mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + pose_condition = self.load_pose_condition(pose_video, num_chunks, num_frames_per_chunk, (height, width), sampling_fps) - mask_lat_size[:, :, list(range(1, num_frames))] = 0 - first_frame_mask = mask_lat_size[:, :, 0:1] - first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) - mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) - mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) - mask_lat_size = mask_lat_size.transpose(1, 2) - mask_lat_size = mask_lat_size.to(latent_condition.device) + return (latents, latent_condition), motion_latents, videos_last_latents, pose_condition - return (latents, torch.concat([mask_lat_size, latent_condition], dim=1)), motion_latents - - def load_pose_condition(self, pose_video, num_repeat, infer_frames, size, sampling_fps): + def load_pose_condition(self, pose_video, num_chunks, num_frames_per_chunk, size, sampling_fps): HEIGHT, WIDTH = size - if not pose_video is None: + if pose_video is not None: pose_seq = self.read_last_n_frames( - pose_video, - n_frames=infer_frames * num_repeat, - target_fps=sampling_fps, - reverse=True) + pose_video, n_frames=num_frames_per_chunk * num_chunks, target_fps=sampling_fps, reverse=True + ) resize_opreat = transforms.Resize(min(HEIGHT, WIDTH)) crop_opreat = transforms.CenterCrop((HEIGHT, WIDTH)) - tensor_trans = transforms.ToTensor() cond_tensor = torch.from_numpy(pose_seq) cond_tensor = cond_tensor.permute(0, 3, 1, 2) / 255.0 * 2 - 1.0 - cond_tensor = crop_opreat(resize_opreat(cond_tensor)).permute( - 1, 0, 2, 3).unsqueeze(0) + cond_tensor = crop_opreat(resize_opreat(cond_tensor)).permute(1, 0, 2, 3).unsqueeze(0) - padding_frame_num = num_repeat * infer_frames - cond_tensor.shape[2] - cond_tensor = torch.cat([ - cond_tensor, - - torch.ones([1, 3, padding_frame_num, HEIGHT, WIDTH]) - ], - dim=2) + padding_frame_num = num_chunks * num_frames_per_chunk - cond_tensor.shape[2] + cond_tensor = torch.cat([cond_tensor, -torch.ones([1, 3, padding_frame_num, 
HEIGHT, WIDTH])], dim=2) - cond_tensors = torch.chunk(cond_tensor, num_repeat, dim=2) + cond_tensors = torch.chunk(cond_tensor, num_chunks, dim=2) else: - cond_tensors = [-torch.ones([1, 3, infer_frames, HEIGHT, WIDTH])] + cond_tensors = [-torch.ones([1, 3, num_frames_per_chunk, HEIGHT, WIDTH])] - COND = [] + pose_condition = [] for r in range(len(cond_tensors)): cond = cond_tensors[r] - cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], - dim=2) - cond_lat = torch.stack( - self.vae.encode( - cond.to(dtype=self.param_dtype, - device=self.device)))[:, :, - 1:].cpu() # for mem save - COND.append(cond_lat) - - return COND - - def read_last_n_frames(self, - video_path, - n_frames, - target_fps=16, - reverse=False): + cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2) + cond_lat = torch.stack(self.vae.encode(cond.to(dtype=self.param_dtype, device=self.device)))[ + :, :, 1: + ].cpu() # for mem save + pose_condition.append(cond_lat) + + return pose_condition + + def read_last_n_frames(self, video_path, n_frames, target_fps=16, reverse=False): """ Read the last `n_frames` from a video at the specified frame rate. @@ -646,8 +618,7 @@ def read_last_n_frames(self, required_span = (n_frames - 1) * interval - start_frame = max(0, total_frames - required_span - - 1) if not reverse else 0 + start_frame = max(0, total_frames - required_span - 1) if not reverse else 0 sampled_indices = [] for i in range(n_frames): @@ -695,7 +666,7 @@ def __call__( pose_video: Optional[List[Image.Image]] = None, height: int = 480, width: int = 832, - num_frames: int = 80, + num_frames_per_chunk: int = 81, num_inference_steps: int = 50, guidance_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, @@ -715,6 +686,7 @@ def __call__( max_sequence_length: int = 512, init_first_frame: bool = False, sampling_fps: int = 16, + num_chunks: Optional[int] = None, ): r""" The call function to the pipeline for generation. @@ -739,8 +711,8 @@ def __call__( The height of the generated video. width (`int`, defaults to `832`): The width of the generated video. - num_frames (`int`, defaults to `80`): - The number of frames in the generated video. The number should be a multiple of 4. + num_frames_per_chunk (`int`, defaults to `81`): + The number of frames in each chunk of the generated video. `num_frames_per_chunk` - 1 should be a multiple of 4. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -795,7 +767,9 @@ def __call__( Whether to use the reference image as the first frame (i.e., standard image-to-video generation). sampling_fps (`int`, *optional*, defaults to 16): The frame rate (in frames per second) at which the generated video will be sampled. - + num_chunks (`int`, *optional*, defaults to None): + The number of chunks to process. If not provided, the number of chunks will be + determined by the audio input to generate whole audio. E.g., If the input audio has 4 chunks, then user can set num_chunks=1 to see 1 out of 4 chunks only without generating the whole video. Examples: Returns: @@ -823,12 +797,12 @@ def __call__( audio_embeds, ) - if num_frames % self.vae_scale_factor_temporal != 1: + if num_frames_per_chunk % self.vae_scale_factor_temporal != 1: logger.warning( - f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + f"`num_frames_per_chunk - 1` has to be divisible by {self.vae_scale_factor_temporal}. 
Rounding to the nearest number." ) - num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 - num_frames = max(num_frames, 1) + num_frames_per_chunk = num_frames_per_chunk // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames_per_chunk = max(num_frames_per_chunk, 1) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs @@ -863,10 +837,11 @@ def __call__( negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) if audio_embeds is None: - audio_embeds, num_repeat = self.encode_audio(audio, sampling_rate, num_frames, sampling_fps, device) - # TODO: num_repeat_input vs. num_repeat? - if num_repeat_input is None or num_repeat_input > num_repeat: - num_repeat_input = num_repeat + audio_embeds, num_chunks_audio = self.encode_audio( + audio, sampling_rate, num_frames_per_chunk, sampling_fps, device + ) + if num_chunks is None or num_chunks > num_chunks_audio: + num_chunks = num_chunks_audio audio_embeds = audio_embeds.repeat(batch_size, 1, 1) audio_embeds = audio_embeds.to(transformer_dtype) @@ -880,21 +855,21 @@ def __call__( num_channels_latents = self.vae.config.z_dim image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) - latents_outputs, motion_latents = self.prepare_latents( + latents_outputs, motion_latents, videos_last_latents, pose_condition = self.prepare_latents( image, batch_size * num_videos_per_prompt, num_channels_latents, height, width, - num_frames, + num_frames_per_chunk, torch.float32, device, generator, latents, init_first_frame, pose_video, - num_repeat_input, - sampling_fps + num_chunks, + sampling_fps, ) latents, condition = latents_outputs @@ -903,57 +878,111 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - - latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) - timestep = t.expand(latents.shape[0]) - - with self.transformer.cache_context("cond"): - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if self.do_classifier_free_guidance: - with self.transformer.cache_context("uncond"): - noise_uncond = self.transformer( + all_latents = [] + for r in range(num_chunks): + latent_target_frames = (num_frames_per_chunk + 3 + self.motion_frames) // 4 - latent_motion_frames + target_shape = [latent_target_frames, height // 8, width // 8] + max_seq_len = np.prod(target_shape) // 4 + latents = deepcopy(latents) + with torch.no_grad(): + left_idx = r * num_frames_per_chunk + right_idx = r * num_frames_per_chunk + num_frames_per_chunk + pose_latents = pose_condition[r] if pose_video else pose_condition[0] + pose_latents = pose_latents.to(dtype=transformer_dtype, device=device) + audio_embeds_input = audio_embeds[..., left_idx:right_idx] + motion_latents_input = motion_latents.clone() + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + with 
self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - encoder_hidden_states_image=image_embeds, + encoder_hidden_states=prompt_embeds, + motion_latents=motion_latents_input, + image_latents=condition, + pose_latents=pose_latents, + audio_embeds=audio_embeds_input, + motion_frames=[self.motion_frames, latent_motion_frames], + drop_motion_frames=self.drop_first_motion and r == 0, attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + motion_latents=motion_latents_input, + image_latents=condition, + audio_embeds=0.0 * audio_embeds_input, + motion_frames=[self.motion_frames, latent_motion_frames], + drop_motion_frames=self.drop_first_motion and r == 0, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not (self.drop_first_motion and r == 0): + decode_latents = torch.cat([motion_latents, latents], dim=2) + else: + decode_latents = torch.cat([condition, latents], dim=2) + + # Work in latent space - no decode-encode cycle + num_latent_frames = (num_frames_per_chunk + 3) // 4 + segment_latents = decode_latents[:, :, -num_latent_frames:] + if self.drop_first_motion and r == 0: + # Adjust for latent space temporal compression + segment_latents = segment_latents[:, :, (3 + 3) // 4 :] + + num_latent_overlap_frames = min(latent_motion_frames, segment_latents.shape[2]) + videos_last_latents = torch.cat( + [ + videos_last_latents[:, :, num_latent_overlap_frames:], + segment_latents[:, :, -num_latent_overlap_frames:], + ], + dim=2, + ) - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # Update motion_latents for next iteration (stay in latent space) + motion_latents = videos_last_latents.to(dtype=motion_latents.dtype, device=motion_latents.device) - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 
1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() + # Store latents instead of decoded frames + all_latents.append(segment_latents) - if XLA_AVAILABLE: - xm.mark_step() + # Decode all accumulated latents once at the end + all_latents = torch.cat(all_latents, dim=2) + latents = all_latents self._current_timestep = None From 66ec4fff3299c21029b65fd0d4caf9603a10d6da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 1 Sep 2025 10:22:28 +0300 Subject: [PATCH 013/131] up --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 306383b4cd82..2836abd3e45b 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -560,7 +560,9 @@ def prepare_latents( videos_last_latents = motion_latents.detach() # Get pose condition input if needed - pose_condition = self.load_pose_condition(pose_video, num_chunks, num_frames_per_chunk, (height, width), sampling_fps) + pose_condition = self.load_pose_condition( + pose_video, num_chunks, num_frames_per_chunk, (height, width), sampling_fps + ) return (latents, latent_condition), motion_latents, videos_last_latents, pose_condition @@ -712,7 +714,8 @@ def __call__( width (`int`, defaults to `832`): The width of the generated video. num_frames_per_chunk (`int`, defaults to `81`): - The number of frames in each chunk of the generated video. `num_frames_per_chunk` - 1 should be a multiple of 4. + The number of frames in each chunk of the generated video. `num_frames_per_chunk` - 1 should be a + multiple of 4. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -768,8 +771,9 @@ def __call__( sampling_fps (`int`, *optional*, defaults to 16): The frame rate (in frames per second) at which the generated video will be sampled. num_chunks (`int`, *optional*, defaults to None): - The number of chunks to process. If not provided, the number of chunks will be - determined by the audio input to generate whole audio. E.g., If the input audio has 4 chunks, then user can set num_chunks=1 to see 1 out of 4 chunks only without generating the whole video. + The number of chunks to process. If not provided, the number of chunks will be determined by the audio + input to generate whole audio. E.g., If the input audio has 4 chunks, then user can set num_chunks=1 to + see 1 out of 4 chunks only without generating the whole video. Examples: Returns: @@ -801,7 +805,9 @@ def __call__( logger.warning( f"`num_frames_per_chunk - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." 
) - num_frames_per_chunk = num_frames_per_chunk // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames_per_chunk = ( + num_frames_per_chunk // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + ) num_frames_per_chunk = max(num_frames_per_chunk, 1) self._guidance_scale = guidance_scale From d6ec4654cf841b52177ea9678c2c4da94bb452d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 1 Sep 2025 11:38:34 +0300 Subject: [PATCH 014/131] Refactor latent preparation for S2V --- .../pipelines/wan/pipeline_wan_s2v.py | 151 ++++++++++-------- 1 file changed, 82 insertions(+), 69 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 2836abd3e45b..c05b04531c4c 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -492,6 +492,7 @@ def prepare_latents( self, image: PipelineImageInput, batch_size: int, + latent_motion_frames: int, num_channels_latents: int = 16, height: int = 480, width: int = 832, @@ -500,12 +501,15 @@ def prepare_latents( device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - init_first_frame: Optional[bool] = False, pose_video: Optional[List[Image.Image]] = None, - num_chunks: Optional[int] = 1, - sampling_fps: Optional[int] = 16, - ) -> Tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: - num_latent_frames = (num_frames_per_chunk - 1) // self.vae_scale_factor_temporal + 1 + init_first_frame: bool = False, + num_chunks: int = 1, + sampling_fps: int = 16, + transformer_dtype: torch.dtype = torch.bfloat16, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[torch.Tensor]]]: + num_latent_frames = ( + num_frames_per_chunk + 3 + self.motion_frames + ) // self.vae_scale_factor_temporal - latent_motion_frames latent_height = height // self.vae_scale_factor_spatial latent_width = width // self.vae_scale_factor_spatial @@ -516,55 +520,59 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - motion_pixels = torch.zeros([1, 3, self.motion_frames, height, width], dtype=dtype, device=device) - if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device=device, dtype=dtype) - image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] + if image is not None: + image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] - video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames_per_chunk - 1, height, width)], dim=2 - ) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames_per_chunk - 1, height, width)], + dim=2, + ) - video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) - if isinstance(generator, list): - latent_condition = [ - retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator - ] - latent_condition = torch.cat(latent_condition) - else: - latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") - latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) - - latent_condition = latent_condition.to(dtype) - latent_condition = (latent_condition - latents_mean) * latents_std - - # Encode motion latents - if init_first_frame: - self.drop_first_motion = False - motion_pixels[:, :, -6:] = video_condition - motion_latents = torch.stack(self.vae.encode(motion_pixels)) - videos_last_latents = motion_latents.detach() - - # Get pose condition input if needed - pose_condition = self.load_pose_condition( - pose_video, num_chunks, num_frames_per_chunk, (height, width), sampling_fps - ) + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) - return (latents, latent_condition), motion_latents, videos_last_latents, pose_condition + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + motion_pixels = torch.zeros( + [1, 3, self.motion_frames, height, width], dtype=transformer_dtype, device=device + ) + # Get pose condition input if needed + pose_condition = self.load_pose_condition( + pose_video, num_chunks, num_frames_per_chunk, (height, width), sampling_fps + ) + # Encode motion latents + if init_first_frame: + self.drop_first_motion = False + motion_pixels[:, :, -6:] = latent_condition + motion_latents = torch.stack(self.vae.encode(motion_pixels)) + videos_last_latents = motion_latents.detach() + + return latents, latent_condition, videos_last_latents, motion_latents, pose_condition + 
else: + return latents def load_pose_condition(self, pose_video, num_chunks, num_frames_per_chunk, size, sampling_fps): HEIGHT, WIDTH = size @@ -861,31 +869,36 @@ def __call__( num_channels_latents = self.vae.config.z_dim image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) - latents_outputs, motion_latents, videos_last_latents, pose_condition = self.prepare_latents( - image, - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - num_frames_per_chunk, - torch.float32, - device, - generator, - latents, - init_first_frame, - pose_video, - num_chunks, - sampling_fps, - ) - - latents, condition = latents_outputs - # 6. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) all_latents = [] for r in range(num_chunks): + latents_outputs = self.prepare_latents( + image if r == 0 else None, + batch_size * num_videos_per_prompt, + latent_motion_frames, + num_channels_latents, + height, + width, + num_frames_per_chunk, + torch.float32, + device, + generator, + latents, + pose_video, + init_first_frame, + num_chunks, + sampling_fps, + transformer_dtype, + ) + + if r == 0: + latents, condition, videos_last_latents, motion_latents, pose_condition = latents_outputs + else: + latents = latents_outputs + latent_target_frames = (num_frames_per_chunk + 3 + self.motion_frames) // 4 - latent_motion_frames target_shape = [latent_target_frames, height // 8, width // 8] max_seq_len = np.prod(target_shape) // 4 @@ -919,6 +932,7 @@ def __call__( audio_embeds=audio_embeds_input, motion_frames=[self.motion_frames, latent_motion_frames], drop_motion_frames=self.drop_first_motion and r == 0, + max_seq_len=max_seq_len, attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -934,6 +948,7 @@ def __call__( audio_embeds=0.0 * audio_embeds_input, motion_frames=[self.motion_frames, latent_motion_frames], drop_motion_frames=self.drop_first_motion and r == 0, + max_seq_len=max_seq_len, attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -965,11 +980,10 @@ def __call__( decode_latents = torch.cat([condition, latents], dim=2) # Work in latent space - no decode-encode cycle - num_latent_frames = (num_frames_per_chunk + 3) // 4 + num_latent_frames = (num_frames_per_chunk + 3) // self.vae_scale_factor_temporal segment_latents = decode_latents[:, :, -num_latent_frames:] if self.drop_first_motion and r == 0: - # Adjust for latent space temporal compression - segment_latents = segment_latents[:, :, (3 + 3) // 4 :] + segment_latents = segment_latents[:, :, (3 + 3) // self.vae_scale_factor_temporal :] num_latent_overlap_frames = min(latent_motion_frames, segment_latents.shape[2]) videos_last_latents = torch.cat( @@ -980,13 +994,12 @@ def __call__( dim=2, ) - # Update motion_latents for next iteration (stay in latent space) + # Update motion_latents for next iteration motion_latents = videos_last_latents.to(dtype=motion_latents.dtype, device=motion_latents.device) - # Store latents instead of decoded frames + # Accumulate latents so as to decode them all at once at the end all_latents.append(segment_latents) - # Decode all accumulated latents once at the end all_latents = torch.cat(all_latents, dim=2) latents = all_latents From a463c0942633b08b2d39c559d89117bff1c6fee2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 1 Sep 2025 15:45:23 +0300 Subject: [PATCH 015/131] up --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 2 -- 1 file 
changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index c05b04531c4c..7399538d7f07 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -932,7 +932,6 @@ def __call__( audio_embeds=audio_embeds_input, motion_frames=[self.motion_frames, latent_motion_frames], drop_motion_frames=self.drop_first_motion and r == 0, - max_seq_len=max_seq_len, attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -948,7 +947,6 @@ def __call__( audio_embeds=0.0 * audio_embeds_input, motion_frames=[self.motion_frames, latent_motion_frames], drop_motion_frames=self.drop_first_motion and r == 0, - max_seq_len=max_seq_len, attention_kwargs=attention_kwargs, return_dict=False, )[0] From 65191a95ee5865e72d48303707d6861765f7394e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 1 Sep 2025 15:46:27 +0300 Subject: [PATCH 016/131] feat: Add audio, pose, and advanced motion conditioning Introduces support for audio and pose conditioning, replacing the previous image conditioning mechanism. The model now accepts audio embeddings and pose latents as input. This change also adds two new, mutually exclusive motion processing modules: - `MotionerTransformers`: A transformer-based module for encoding motion. - `FramePackMotioner`: A module that packs frames from different temporal buckets for motion representation. Additionally, an `AudioInjector` module is implemented to fuse audio features into specific transformer blocks using cross-attention. --- .../transformers/transformer_wan_s2v.py | 800 ++++++++++++++++-- 1 file changed, 746 insertions(+), 54 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index f2bdf9f265cf..9fde11ed895e 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -13,16 +13,18 @@ # limitations under the License. 
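# Editorial illustration (not part of this patch): the CausalConv1d added below
# pads the time axis on the left only, so each output step depends just on the
# current and earlier inputs, and the temporal length is preserved at stride 1.
# Shapes here are arbitrary example values, not the model's real configuration.
import torch
import torch.nn.functional as F

kernel_size = 3
x = torch.randn(1, 4, 10)                             # (batch, channels, time)
x = F.pad(x, (kernel_size - 1, 0), mode="replicate")  # left-pad along time only
y = torch.nn.Conv1d(4, 8, kernel_size, stride=1)(x)
assert y.shape == (1, 8, 10)                          # temporal length unchanged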
import math -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn +import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm @@ -30,13 +32,644 @@ WanAttention, WanAttnProcessor, WanRotaryPosEmbed, - WanTimeTextImageEmbedding, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class CausalConv1d(nn.Module): + + def __init__(self, + chan_in, + chan_out, + kernel_size=3, + stride=1, + dilation=1, + pad_mode='replicate', + **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d( + chan_in, + chan_out, + kernel_size, + stride=stride, + dilation=dilation, + **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class MotionEncoder_tc(nn.Module): + + def __init__(self, + in_dim: int, + hidden_dim: int, + num_heads=int, + need_global=True, + dtype=None, + device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.need_global = need_global + self.conv1_local = CausalConv1d( + in_dim, hidden_dim // 4 * num_heads, 3, stride=1) + if need_global: + self.conv1_global = CausalConv1d( + in_dim, hidden_dim // 4, 3, stride=1) + self.norm1 = nn.LayerNorm( + hidden_dim // 4, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2) + self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2) + + if need_global: + self.final_linear = nn.Linear(hidden_dim, hidden_dim, + **factory_kwargs) + + self.norm1 = nn.LayerNorm( + hidden_dim // 4, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + + self.norm2 = nn.LayerNorm( + hidden_dim // 2, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + + self.norm3 = nn.LayerNorm( + hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + x = rearrange(x, 'b t c -> b c t') + x_ori = x.clone() + b, c, t = x.shape + x = self.conv1_local(x) + x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads) + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + if not self.need_global: + return x_local + + x = self.conv1_global(x_ori) + x = rearrange(x, 'b c t -> b t c') + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = 
self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = self.final_linear(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + + return x, x_local + + +class CausalAudioEncoder(nn.Module): + + def __init__(self, + dim=5120, + num_layers=25, + out_dim=2048, + video_rate=8, + num_audio_token=4, + need_global=False): + super().__init__() + self.encoder = MotionEncoder_tc( + in_dim=dim, + hidden_dim=out_dim, + num_heads=num_audio_token, + need_global=need_global) + weight = torch.ones((1, num_layers, 1, 1)) * 0.01 + + self.weights = torch.nn.Parameter(weight) + self.act = torch.nn.SiLU() + + def forward(self, features): + with amp.autocast(dtype=torch.float32): + # features B * num_layers * dim * video_length + weights = self.act(self.weights) + weights_sum = weights.sum(dim=1, keepdims=True) + weighted_feat = ((features * weights) / weights_sum).sum( + dim=1) # b dim f + weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim + res = self.encoder(weighted_feat) # b f n dim + + return res # b f n dim + + +class AudioInjector(nn.Module): + + def __init__(self, + all_modules, + all_modules_names, + dim=2048, + num_heads=32, + inject_layer=[0, 27], + root_net=None, + enable_adain=False, + adain_dim=2048, + need_adain_ont=False): + super().__init__() + num_injector_layers = len(inject_layer) + self.injected_block_id = {} + audio_injector_id = 0 + for mod_name, mod in zip(all_modules_names, all_modules): + if isinstance(mod, WanAttentionBlock): + for inject_id in inject_layer: + if f'transformer_blocks.{inject_id}' in mod_name: + self.injected_block_id[inject_id] = audio_injector_id + audio_injector_id += 1 + + self.injector = nn.ModuleList([ + AudioCrossAttention( + dim=dim, + num_heads=num_heads, + qk_norm=True, + ) for _ in range(audio_injector_id) + ]) + self.injector_pre_norm_feat = nn.ModuleList([ + nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) for _ in range(audio_injector_id) + ]) + self.injector_pre_norm_vec = nn.ModuleList([ + nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) for _ in range(audio_injector_id) + ]) + if enable_adain: + self.injector_adain_layers = nn.ModuleList([ + AdaLayerNorm( + output_dim=dim * 2, embedding_dim=adain_dim, chunk_dim=1) + for _ in range(audio_injector_id) + ]) + if need_adain_ont: + self.injector_adain_output_layers = nn.ModuleList( + [nn.Linear(dim, dim) for _ in range(audio_injector_id)]) + + + +class MotionerTransformers(nn.Module, PeftAdapterMixin): + + def __init__( + self, + patch_size=(1, 2, 2), + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + self_attn_block="SelfAttention", + motion_token_num=1024, + enable_tsm=False, + motion_stride=4, + expand_ratio=2, + trainable_token_pos_emb=False, + ): + super().__init__() + self.patch_size = patch_size + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + self.enable_tsm = enable_tsm + self.motion_stride = motion_stride + self.expand_ratio = expand_ratio + self.sample_c = self.patch_size[0] + + # embeddings + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, 
stride=patch_size) + + # blocks + self.blocks = nn.ModuleList([ + MotionerAttentionBlock( + dim, + ffn_dim, + num_heads, + window_size, + qk_norm, + cross_attn_norm, + eps, + self_attn_block=self_attn_block) for _ in range(num_layers) + ]) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat([ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], + dim=1) + + self.gradient_checkpointing = False + + self.motion_side_len = int(math.sqrt(motion_token_num)) + assert self.motion_side_len**2 == motion_token_num + self.token = nn.Parameter( + torch.zeros(1, motion_token_num, dim).contiguous()) + + self.trainable_token_pos_emb = trainable_token_pos_emb + if trainable_token_pos_emb: + x = torch.zeros([1, motion_token_num, num_heads, d]) + x[..., ::2] = 1 + + gride_sizes = [[ + torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([1, self.motion_side_len, + self.motion_side_len]).unsqueeze(0).repeat(1, 1), + torch.tensor([1, self.motion_side_len, + self.motion_side_len]).unsqueeze(0).repeat(1, 1), + ]] + token_freqs = rope_apply(x, gride_sizes, self.freqs) + token_freqs = token_freqs[0, :, 0].reshape(motion_token_num, -1, 2) + token_freqs = token_freqs * 0.01 + self.token_freqs = torch.nn.Parameter(token_freqs) + + def after_patch_embedding(self, x): + return x + + def forward( + self, + x, + ): + """ + x: A list of videos each with shape [C, T, H, W]. + t: [B]. + context: A list of text embeddings each with shape [L, C]. + """ + # params + motion_frames = x[0].shape[1] + device = self.patch_embedding.weight.device + freqs = self.freqs + if freqs.device != device: + freqs = freqs.to(device) + + if self.trainable_token_pos_emb: + with amp.autocast(dtype=torch.float64): + token_freqs = self.token_freqs.to(torch.float64) + token_freqs = token_freqs / token_freqs.norm( + dim=-1, keepdim=True) + freqs = [freqs, torch.view_as_complex(token_freqs)] + + if self.enable_tsm: + sample_idx = [ + sample_indices( + u.shape[1], + stride=self.motion_stride, + expand_ratio=self.expand_ratio, + c=self.sample_c) for u in x + ] + x = [ + torch.flip(torch.flip(u, [1])[:, idx], [1]) + for idx, u in zip(sample_idx, x) + ] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + x = self.after_patch_embedding(x) + + seq_f, seq_h, seq_w = x[0].shape[-3:] + batch_size = len(x) + if not self.enable_tsm: + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + grid_sizes = [[ + torch.zeros_like(grid_sizes), grid_sizes, grid_sizes + ]] + seq_f = 0 + else: + grid_sizes = [] + for idx in sample_idx[0][::-1][::self.sample_c]: + tsm_frame_grid_sizes = [[ + torch.tensor([idx, 0, + 0]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor([idx + 1, seq_h, + seq_w]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor([1, seq_h, + seq_w]).unsqueeze(0).repeat(batch_size, 1), + ]] + grid_sizes += tsm_frame_grid_sizes + seq_f = sample_idx[0][-1] + 1 + + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + x = torch.cat([u for u in x]) + + batch_size = len(x) + + token_grid_sizes = [[ + torch.tensor([seq_f, 0, 0]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor( + [seq_f + 1, self.motion_side_len, + self.motion_side_len]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor( + [1 if not self.trainable_token_pos_emb else 
-1, seq_h, + seq_w]).unsqueeze(0).repeat(batch_size, 1), + ] # 第三行代表rope emb的想要覆盖到的范围 + ] + + grid_sizes = grid_sizes + token_grid_sizes + token_unpatch_grid_sizes = torch.stack([ + torch.tensor([1, 32, 32], dtype=torch.long) + for b in range(batch_size) + ]) + token_len = self.token.shape[1] + token = self.token.clone().repeat(x.shape[0], 1, 1).contiguous() + seq_lens = seq_lens + torch.tensor([t.size(0) for t in token], + dtype=torch.long) + x = torch.cat([x, token], dim=1) + # arguments + kwargs = dict( + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=freqs, + ) + + # grad ckpt args + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs, **kwargs): + if return_dict is not None: + return module(*inputs, **kwargs, return_dict=return_dict) + else: + return module(*inputs, **kwargs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ({ + "use_reentrant": False + } if is_torch_version(">=", "1.11.0") else {}) + + for idx, block in enumerate(self.blocks): + if self.training and self.gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + **kwargs, + **ckpt_kwargs, + ) + else: + x = block(x, **kwargs) + # head + out = x[:, -token_len:] + return out + + def unpatchify(self, x, grid_sizes): + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + + +class FramePackMotioner(nn.Module): + + def __init__( + self, + inner_dim=1024, + num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design + zip_frame_buckets=[ + 1, 2, 16 + ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames + drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.proj = nn.Conv3d( + 16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d( + 16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d( + 16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + self.zip_frame_buckets = torch.tensor( + zip_frame_buckets, dtype=torch.long) + + self.inner_dim = inner_dim + self.num_heads = num_heads + + assert (inner_dim % + num_heads) == 0 and (inner_dim // num_heads) % 2 == 0 + d = inner_dim // num_heads + self.freqs = torch.cat([ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], + dim=1) + self.drop_mode = drop_mode + + def forward(self, motion_latents, add_last_motion=2): + motion_frames = motion_latents[0].shape[1] + mot = [] + mot_remb = [] + for m in motion_latents: + lat_height, lat_width = m.shape[2], m.shape[3] + padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, + lat_width).to( + device=m.device, dtype=m.dtype) + overlap_frame = min(padd_lat.shape[1], m.shape[1]) + if overlap_frame > 0: + padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:] + + if add_last_motion < 2 and self.drop_mode != "drop": + 
zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets. + __len__() - + add_last_motion - + 1].sum() + padd_lat[:, -zero_end_frame:] = 0 + + padd_lat = padd_lat.unsqueeze(0) + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum( + ):, :, :].split( + list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2 ,1 + + # patchfy + clean_latents_post = self.proj(clean_latents_post).flatten( + 2).transpose(1, 2) + clean_latents_2x = self.proj_2x(clean_latents_2x).flatten( + 2).transpose(1, 2) + clean_latents_4x = self.proj_4x(clean_latents_4x).flatten( + 2).transpose(1, 2) + + if add_last_motion < 2 and self.drop_mode == "drop": + clean_latents_post = clean_latents_post[:, : + 0] if add_last_motion < 2 else clean_latents_post + clean_latents_2x = clean_latents_2x[:, : + 0] if add_last_motion < 1 else clean_latents_2x + + motion_lat = torch.cat( + [clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) + + # rope + start_time_id = -(self.zip_frame_buckets[:1].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[0] + grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \ + [ + [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] + ] + + start_time_id = -(self.zip_frame_buckets[:2].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[1] // 2 + grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \ + [ + [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] + ] + + start_time_id = -(self.zip_frame_buckets[:3].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[2] // 4 + grid_sizes_4x = [[ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 8, + lat_width // 8]).unsqueeze(0).repeat(1, 1), + torch.tensor([ + self.zip_frame_buckets[2], lat_height // 2, lat_width // 2 + ]).unsqueeze(0).repeat(1, 1), + ]] + + grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x + + motion_rope_emb = rope_precompute( + motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, + self.inner_dim // self.num_heads), + grid_sizes, + self.freqs, + start=None) + + mot.append(motion_lat) + mot_remb.append(motion_rope_emb) + return mot, mot_remb + +class WanTimeTextAudioPoseEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + audio_embed_dim: int, + enable_adain: bool = True, + pose_embed_dim: Optional[int] = None, + patch_size: Optional[Tuple[int]] = None, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + self.casual_audio_encoder = CausalAudioEncoder(dim=audio_embed_dim, out_dim=dim, num_audio_token=4, need_global=enable_adain) + + self.pose_embedder = None + if pose_embed_dim is not None: + self.pose_embedder = 
nn.Conv3d(pose_embed_dim, dim, kernel_size=patch_size, stride=patch_size) + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + pose_hidden_states: Optional[torch.Tensor] = None, + timestep_seq_len: Optional[int] = None, + ): + timestep = self.timesteps_proj(timestep) + if timestep_seq_len is not None: + timestep = timestep.unflatten(0, (-1, timestep_seq_len)) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + + audio_hidden_states = self.casual_audio_encoder(audio_hidden_states) + + if self.pose_embedder is not None: + pose_hidden_states = self.pose_embedder(pose_hidden_states) + + return temb, timestep_proj, encoder_hidden_states, audio_hidden_states, pose_hidden_states + + + class WanS2VTransformerBlock(nn.Module): def __init__( self, @@ -186,15 +819,19 @@ def __init__( out_channels: int = 16, text_dim: int = 4096, freq_dim: int = 256, + audio_dim: int = 1280, + enable_adain: bool = True, + pose_dim: int = 1280, ffn_dim: int = 13824, num_layers: int = 40, cross_attn_norm: bool = True, qk_norm: Optional[str] = "rms_norm_across_heads", eps: float = 1e-6, - image_dim: Optional[int] = None, added_kv_proj_dim: Optional[int] = None, rope_max_seq_len: int = 1024, - pos_embed_seq_len: Optional[int] = None, + enable_motioner: bool = False, + enable_framepack: bool = False, + add_last_motion: bool = False, ) -> None: super().__init__() @@ -205,15 +842,57 @@ def __init__( self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + # init motioner + if enable_motioner and enable_framepack: + raise ValueError( + "enable_motioner and enable_framepack are mutually exclusive, please set one of them to False" + ) + self.enable_motioner = enable_motioner + self.add_last_motion = add_last_motion + if enable_motioner: + motioner_dim = 2048 + self.motioner = MotionerTransformers( + patch_size=(2, 4, 4), + dim=motioner_dim, + ffn_dim=motioner_dim, + freq_dim=256, + out_dim=16, + num_heads=16, + num_layers=13, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + motion_token_num=motion_token_num, + enable_tsm=enable_tsm, + motion_stride=4, + expand_ratio=2, + ) + self.zip_motion_out = torch.nn.Sequential( + WanLayerNorm(motioner_dim), + zero_module(nn.Linear(motioner_dim, self.dim))) + + self.enable_framepack = enable_framepack + if enable_framepack: + self.frame_packer = FramePackMotioner( + inner_dim=self.dim, + num_heads=self.num_heads, + zip_frame_buckets=[1, 2, 16], + drop_mode=framepack_drop_mode) + + self.trainable_cond_mask = nn.Embedding(3, self.dim) + # 2. Condition embeddings # image_embedding_dim=1280 for I2V model - self.condition_embedder = WanTimeTextImageEmbedding( + self.condition_embedder = WanTimeTextAudioPoseEmbedding( dim=inner_dim, time_freq_dim=freq_dim, time_proj_dim=inner_dim * 6, text_embed_dim=text_dim, - image_embed_dim=image_dim, - pos_embed_seq_len=pos_embed_seq_len, + audio_embed_dim=audio_dim, + pose_embed_dim=pose_dim, + patch_size=patch_size, + enable_adain=enable_adain, ) # 3. 
Transformer blocks @@ -226,6 +905,18 @@ def __init__( ] ) + self.audio_injector = AudioInjector( + all_modules, + all_modules_names, + dim=self.dim, + num_heads=self.num_heads, + inject_layer=audio_inject_layers, + root_net=self, + enable_adain=enable_adain, + adain_dim=self.dim, + need_adain_ont=adain_mode != "attn_norm", + ) + # 4. Output norm & projection self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) @@ -233,14 +924,57 @@ def __init__( self.gradient_checkpointing = False + def inject_motion(self, + x, + seq_lens, + rope_embs, + mask_input, + motion_latents, + drop_motion_frames=False, + add_last_motion=True): + # inject the motion frames token to the hidden states + if self.enable_motioner: + mot, mot_remb = self.process_motion_transformer_motioner( + motion_latents, + drop_motion_frames=drop_motion_frames, + add_last_motion=add_last_motion) + elif self.enable_framepack: + mot, mot_remb = self.process_motion_frame_pack( + motion_latents, + drop_motion_frames=drop_motion_frames, + add_last_motion=add_last_motion) + else: + mot, mot_remb = self.process_motion( + motion_latents, drop_motion_frames=drop_motion_frames) + + if len(mot) > 0: + x = [torch.cat([u, m], dim=1) for u, m in zip(x, mot)] + seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], + dtype=torch.long) + rope_embs = [ + torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb) + ] + mask_input = [ + torch.cat([ + m, 2 * torch.ones([1, u.shape[1] - m.shape[1]], + device=m.device, + dtype=m.dtype) + ], + dim=1) for m, u in zip(mask_input, x) + ] + return x, seq_lens, rope_embs, mask_input + def forward( self, hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, - encoder_hidden_states_image: Optional[torch.Tensor] = None, - control_hidden_states: torch.Tensor = None, - control_hidden_states_scale: torch.Tensor = None, + motion_latents: torch.Tensor, + image_latents: torch.Tensor = None, + pose_latents: torch.Tensor = None, + audio_embeds: torch.Tensor = None, + motion_frames: List[int] = None, + drop_motion_frames: bool = False, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: @@ -265,15 +999,6 @@ def forward( post_patch_height = height // p_h post_patch_width = width // p_w - if control_hidden_states_scale is None: - control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers)) - control_hidden_states_scale = torch.unbind(control_hidden_states_scale) - if len(control_hidden_states_scale) != len(self.config.vace_layers): - raise ValueError( - f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be " - f"equal to {len(self.config.vace_layers)}." - ) - # 1. Rotary position embedding rotary_emb = self.rope(hidden_states) @@ -281,56 +1006,23 @@ def forward( hidden_states = self.patch_embedding(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) - control_hidden_states = self.vace_patch_embedding(control_hidden_states) - control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2) - control_hidden_states_padding = control_hidden_states.new_zeros( - batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2) - ) - control_hidden_states = torch.cat([control_hidden_states, control_hidden_states_padding], dim=1) - # 3. 
Time embedding - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( - timestep, encoder_hidden_states, encoder_hidden_states_image + temb, timestep_proj, encoder_hidden_states, audio_hidden_states, pose_hidden_states = self.condition_embedder( + timestep, encoder_hidden_states, audio_embeds, pose_latents ) timestep_proj = timestep_proj.unflatten(1, (6, -1)) # 4. Image embedding - if encoder_hidden_states_image is not None: - encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) # 5. Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: - # Prepare VACE hints - control_hidden_states_list = [] - for i, block in enumerate(self.vace_blocks): - conditioning_states, control_hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb - ) - control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i])) - control_hidden_states_list = control_hidden_states_list[::-1] - for i, block in enumerate(self.blocks): hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb ) - if i in self.config.vace_layers: - control_hint, scale = control_hidden_states_list.pop() - hidden_states = hidden_states + control_hint * scale else: - # Prepare VACE hints - control_hidden_states_list = [] - for i, block in enumerate(self.vace_blocks): - conditioning_states, control_hidden_states = block( - hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb - ) - control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i])) - control_hidden_states_list = control_hidden_states_list[::-1] - for i, block in enumerate(self.blocks): hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) - if i in self.config.vace_layers: - control_hint, scale = control_hidden_states_list.pop() - hidden_states = hidden_states + control_hint * scale # 6. Output norm, projection & unpatchify shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) From 79252291796990a2d10c08fb0d6c4dc1ecdd3a5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 1 Sep 2025 18:24:32 +0300 Subject: [PATCH 017/131] Refactor `WanS2V` transformer and introduce FramePack motioner The `MotionerTransformers` module is removed and its functionality is replaced by a `FramePackMotioner` module and a simplified standard motion processing pipeline. The codebase is refactored to remove the `einops` dependency, replacing `rearrange` operations with standard PyTorch tensor manipulations for better code consistency. Additionally, `AdaLayerNorm` is introduced for improved conditioning, and helper functions for Rotary Positional Embeddings (RoPE) are added (probably temporarily) and refactored for clarity and flexibility. The audio injection mechanism is also updated to align with the new model structure. 
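For reference, here is a minimal standalone sketch (with made-up shapes, not the
pipeline's real configuration) of how the einops patterns dropped by this refactor
map onto the plain PyTorch ops now used in `MotionEncoder_tc`:

    import torch

    batch, heads, chan, time = 2, 4, 8, 16      # hypothetical sizes
    x = torch.randn(batch, heads * chan, time)

    # rearrange(x, "b c t -> b t c") is a simple axis swap:
    y = x.permute(0, 2, 1)                      # (batch, time, heads * chan)

    # rearrange(x, "b (n c) t -> (b n) t c", n=heads):
    z = (
        x.unflatten(1, (heads, chan))           # (batch, heads, chan, time)
        .permute(0, 1, 3, 2)                    # (batch, heads, time, chan)
        .reshape(batch * heads, time, chan)
    )

    # rearrange(z, "(b n) t c -> b t n c", b=batch):
    w = z.unflatten(0, (batch, heads)).permute(0, 2, 1, 3)

    assert y.shape == (batch, time, heads * chan)
    assert z.shape == (batch * heads, time, chan)
    assert w.shape == (batch, time, heads, chan)

The unflatten/permute/reshape chain keeps the head-major grouping of the channel
axis, so it reproduces the einops semantics without the extra dependency.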
--- .../transformers/transformer_wan_s2v.py | 902 ++++++++---------- 1 file changed, 386 insertions(+), 516 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 9fde11ed895e..c3065e6522bd 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -15,6 +15,7 @@ import math from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -38,29 +39,157 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class CausalConv1d(nn.Module): +def torch_dfs(model: nn.Module, parent_name="root"): + module_names, modules = [], [] + current_name = parent_name if parent_name else "root" + module_names.append(current_name) + modules.append(model) + + for name, child in model.named_children(): + if parent_name: + child_name = f"{parent_name}.{name}" + else: + child_name = name + child_modules, child_names = torch_dfs(child, child_name) + module_names += child_names + modules += child_modules + return modules, module_names + + +def rope_precompute(x, grid_sizes, freqs, start=None): + b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2 + + # split freqs + if type(freqs) is list: + trainable_freqs = freqs[1] + freqs = freqs[0] + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64)) + seq_bucket = [0] + if type(grid_sizes) is not list: + grid_sizes = [grid_sizes] + for g in grid_sizes: + if type(g) is not list: + g = [torch.zeros_like(g), g] + batch_size = g[0].shape[0] + for i in range(batch_size): + if start is None: + f_o, h_o, w_o = g[0][i] + else: + f_o, h_o, w_o = start[i] + + f, h, w = g[1][i] + t_f, t_h, t_w = g[2][i] + seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o + seq_len = int(seq_f * seq_h * seq_w) + if seq_len > 0: + if t_f > 0: + # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item()) + if f_o >= 0: + f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist() + else: + f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist() + h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist() + w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist() + + assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0 + freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj() + freqs_0 = freqs_0.view(seq_f, 1, 1, -1) + + freqs_i = torch.cat( + [ + freqs_0.expand(seq_f, seq_h, seq_w, -1), + freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + elif t_f < 0: + freqs_i = trainable_freqs.unsqueeze(1) + # apply rotary embedding + output[i, seq_bucket[-1] : seq_bucket[-1] + seq_len] = freqs_i + seq_bucket.append(seq_bucket[-1] + seq_len) + return output + + +@torch.amp.autocast("cuda", enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)) + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +class AdaLayerNorm(nn.Module): + r""" + Norm 
layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`, *optional*): The size of the embeddings dictionary. + output_dim (`int`, *optional*): + norm_elementwise_affine (`bool`, defaults to `False): + norm_eps (`bool`, defaults to `False`): + chunk_dim (`int`, defaults to `0`): + """ - def __init__(self, - chan_in, - chan_out, - kernel_size=3, - stride=1, - dilation=1, - pad_mode='replicate', - **kwargs): + def __init__( + self, + embedding_dim: int, + num_embeddings: Optional[int] = None, + output_dim: Optional[int] = None, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + chunk_dim: int = 0, + ): + super().__init__() + + self.chunk_dim = chunk_dim + output_dim = output_dim or embedding_dim * 2 + + if num_embeddings is not None: + self.emb = nn.Embedding(num_embeddings, embedding_dim) + else: + self.emb = None + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) + + def forward( + self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if self.emb is not None: + temb = self.emb(timestep) + + temb = self.linear(self.silu(temb)) + + if self.chunk_dim == 1: + # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the + # other if-branch. This branch is specific to CogVideoX and OmniGen for now. + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + else: + scale, shift = temb.chunk(2, dim=0) + + x = self.norm(x) * (1 + scale) + shift + return x + + +class CausalConv1d(nn.Module): + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): super().__init__() self.pad_mode = pad_mode padding = (kernel_size - 1, 0) # T self.time_causal_padding = padding - self.conv = nn.Conv1d( - chan_in, - chan_out, - kernel_size, - stride=stride, - dilation=dilation, - **kwargs) + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) def forward(self, x): x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) @@ -68,554 +197,299 @@ def forward(self, x): class MotionEncoder_tc(nn.Module): - - def __init__(self, - in_dim: int, - hidden_dim: int, - num_heads=int, - need_global=True, - dtype=None, - device=None): + def __init__( + self, in_dim: int, hidden_dim: int, num_attention_heads=int, need_global=True, dtype=None, device=None + ): factory_kwargs = {"dtype": dtype, "device": device} super().__init__() - self.num_heads = num_heads + self.num_attention_heads = num_attention_heads self.need_global = need_global - self.conv1_local = CausalConv1d( - in_dim, hidden_dim // 4 * num_heads, 3, stride=1) + self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_attention_heads, 3, stride=1) if need_global: - self.conv1_global = CausalConv1d( - in_dim, hidden_dim // 4, 3, stride=1) - self.norm1 = nn.LayerNorm( - hidden_dim // 4, - elementwise_affine=False, - eps=1e-6, - **factory_kwargs) + self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs) self.act = nn.SiLU() self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2) self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2) if need_global: - self.final_linear = 
nn.Linear(hidden_dim, hidden_dim, - **factory_kwargs) + self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs) - self.norm1 = nn.LayerNorm( - hidden_dim // 4, - elementwise_affine=False, - eps=1e-6, - **factory_kwargs) + self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs) - self.norm2 = nn.LayerNorm( - hidden_dim // 2, - elementwise_affine=False, - eps=1e-6, - **factory_kwargs) + self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs) - self.norm3 = nn.LayerNorm( - hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs) self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) def forward(self, x): - x = rearrange(x, 'b t c -> b c t') - x_ori = x.clone() - b, c, t = x.shape + x = x.permute(0, 2, 1) + residual = x.clone() + batch_size, num_channels, seq_len = x.shape x = self.conv1_local(x) - x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads) + x = ( + x.unflatten(1, (self.num_attention_heads, -1)) + .permute(0, 1, 3, 2) + .reshape(batch_size * self.num_attention_heads, seq_len, num_channels) + ) x = self.norm1(x) x = self.act(x) - x = rearrange(x, 'b t c -> b c t') + x = x.permute(0, 2, 1) x = self.conv2(x) - x = rearrange(x, 'b c t -> b t c') + x = x.permute(0, 2, 1) x = self.norm2(x) x = self.act(x) - x = rearrange(x, 'b t c -> b c t') + x = x.permute(0, 2, 1) x = self.conv3(x) - x = rearrange(x, 'b c t -> b t c') + x = x.permute(0, 2, 1) x = self.norm3(x) x = self.act(x) - x = rearrange(x, '(b n) t c -> b t n c', b=b) - padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) + padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1) x = torch.cat([x, padding], dim=-2) x_local = x.clone() if not self.need_global: return x_local - x = self.conv1_global(x_ori) - x = rearrange(x, 'b c t -> b t c') + x = self.conv1_global(residual) + x = x.permute(0, 2, 1) x = self.norm1(x) x = self.act(x) - x = rearrange(x, 'b t c -> b c t') + x = x.permute(0, 2, 1) x = self.conv2(x) - x = rearrange(x, 'b c t -> b t c') + x = x.permute(0, 2, 1) x = self.norm2(x) x = self.act(x) - x = rearrange(x, 'b t c -> b c t') + x = x.permute(0, 2, 1) x = self.conv3(x) - x = rearrange(x, 'b c t -> b t c') + x = x.permute(0, 2, 1) x = self.norm3(x) x = self.act(x) x = self.final_linear(x) - x = rearrange(x, '(b n) t c -> b t n c', b=b) + x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) return x, x_local class CausalAudioEncoder(nn.Module): - - def __init__(self, - dim=5120, - num_layers=25, - out_dim=2048, - video_rate=8, - num_audio_token=4, - need_global=False): + def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_audio_token=4, need_global=False): super().__init__() self.encoder = MotionEncoder_tc( - in_dim=dim, - hidden_dim=out_dim, - num_heads=num_audio_token, - need_global=need_global) + in_dim=dim, hidden_dim=out_dim, num_heads=num_audio_token, need_global=need_global + ) weight = torch.ones((1, num_layers, 1, 1)) * 0.01 self.weights = torch.nn.Parameter(weight) self.act = torch.nn.SiLU() def forward(self, features): - with amp.autocast(dtype=torch.float32): - # features B * num_layers * dim * video_length - weights = self.act(self.weights) - weights_sum = weights.sum(dim=1, keepdims=True) - weighted_feat = ((features * weights) / weights_sum).sum( - dim=1) # b dim f - weighted_feat = 
weighted_feat.permute(0, 2, 1) # b f dim - res = self.encoder(weighted_feat) # b f n dim + # features B * num_layers * dim * video_length + weights = self.act(self.weights) + weights_sum = weights.sum(dim=1, keepdims=True) + weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f + weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim + res = self.encoder(weighted_feat) # b f n dim return res # b f n dim class AudioInjector(nn.Module): - - def __init__(self, - all_modules, - all_modules_names, - dim=2048, - num_heads=32, - inject_layer=[0, 27], - root_net=None, - enable_adain=False, - adain_dim=2048, - need_adain_ont=False): + def __init__( + self, + all_modules, + all_modules_names, + dim=2048, + num_heads=32, + inject_layer=[0, 27], + enable_adain=False, + adain_dim=2048, + need_adain_ont=False, + eps=1e-6, + added_kv_proj_dim=None, + ): super().__init__() - num_injector_layers = len(inject_layer) self.injected_block_id = {} audio_injector_id = 0 for mod_name, mod in zip(all_modules_names, all_modules): - if isinstance(mod, WanAttentionBlock): + if isinstance(mod, WanAttention): for inject_id in inject_layer: - if f'transformer_blocks.{inject_id}' in mod_name: + if f"transformer_blocks.{inject_id}" in mod_name: self.injected_block_id[inject_id] = audio_injector_id audio_injector_id += 1 - self.injector = nn.ModuleList([ - AudioCrossAttention( - dim=dim, - num_heads=num_heads, - qk_norm=True, - ) for _ in range(audio_injector_id) - ]) - self.injector_pre_norm_feat = nn.ModuleList([ - nn.LayerNorm( - dim, - elementwise_affine=False, - eps=1e-6, - ) for _ in range(audio_injector_id) - ]) - self.injector_pre_norm_vec = nn.ModuleList([ - nn.LayerNorm( - dim, - elementwise_affine=False, - eps=1e-6, - ) for _ in range(audio_injector_id) - ]) - if enable_adain: - self.injector_adain_layers = nn.ModuleList([ - AdaLayerNorm( - output_dim=dim * 2, embedding_dim=adain_dim, chunk_dim=1) + # 2. 
Cross-attention + self.injector = nn.ModuleList( + [ + WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + cross_attention_dim_head=dim // num_heads, + processor=WanAttnProcessor(), + ) for _ in range(audio_injector_id) - ]) - if need_adain_ont: - self.injector_adain_output_layers = nn.ModuleList( - [nn.Linear(dim, dim) for _ in range(audio_injector_id)]) - - - -class MotionerTransformers(nn.Module, PeftAdapterMixin): - - def __init__( - self, - patch_size=(1, 2, 2), - in_dim=16, - dim=2048, - ffn_dim=8192, - freq_dim=256, - out_dim=16, - num_heads=16, - num_layers=32, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=False, - eps=1e-6, - self_attn_block="SelfAttention", - motion_token_num=1024, - enable_tsm=False, - motion_stride=4, - expand_ratio=2, - trainable_token_pos_emb=False, - ): - super().__init__() - self.patch_size = patch_size - self.in_dim = in_dim - self.dim = dim - self.ffn_dim = ffn_dim - self.freq_dim = freq_dim - self.out_dim = out_dim - self.num_heads = num_heads - self.num_layers = num_layers - self.window_size = window_size - self.qk_norm = qk_norm - self.cross_attn_norm = cross_attn_norm - self.eps = eps - - self.enable_tsm = enable_tsm - self.motion_stride = motion_stride - self.expand_ratio = expand_ratio - self.sample_c = self.patch_size[0] - - # embeddings - self.patch_embedding = nn.Conv3d( - in_dim, dim, kernel_size=patch_size, stride=patch_size) - - # blocks - self.blocks = nn.ModuleList([ - MotionerAttentionBlock( - dim, - ffn_dim, - num_heads, - window_size, - qk_norm, - cross_attn_norm, - eps, - self_attn_block=self_attn_block) for _ in range(num_layers) - ]) - - # buffers (don't use register_buffer otherwise dtype will be changed in to()) - assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 - d = dim // num_heads - self.freqs = torch.cat([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)) - ], - dim=1) - - self.gradient_checkpointing = False - - self.motion_side_len = int(math.sqrt(motion_token_num)) - assert self.motion_side_len**2 == motion_token_num - self.token = nn.Parameter( - torch.zeros(1, motion_token_num, dim).contiguous()) - - self.trainable_token_pos_emb = trainable_token_pos_emb - if trainable_token_pos_emb: - x = torch.zeros([1, motion_token_num, num_heads, d]) - x[..., ::2] = 1 - - gride_sizes = [[ - torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1), - torch.tensor([1, self.motion_side_len, - self.motion_side_len]).unsqueeze(0).repeat(1, 1), - torch.tensor([1, self.motion_side_len, - self.motion_side_len]).unsqueeze(0).repeat(1, 1), - ]] - token_freqs = rope_apply(x, gride_sizes, self.freqs) - token_freqs = token_freqs[0, :, 0].reshape(motion_token_num, -1, 2) - token_freqs = token_freqs * 0.01 - self.token_freqs = torch.nn.Parameter(token_freqs) - - def after_patch_embedding(self, x): - return x - - def forward( - self, - x, - ): - """ - x: A list of videos each with shape [C, T, H, W]. - t: [B]. - context: A list of text embeddings each with shape [L, C]. 
- """ - # params - motion_frames = x[0].shape[1] - device = self.patch_embedding.weight.device - freqs = self.freqs - if freqs.device != device: - freqs = freqs.to(device) - - if self.trainable_token_pos_emb: - with amp.autocast(dtype=torch.float64): - token_freqs = self.token_freqs.to(torch.float64) - token_freqs = token_freqs / token_freqs.norm( - dim=-1, keepdim=True) - freqs = [freqs, torch.view_as_complex(token_freqs)] - - if self.enable_tsm: - sample_idx = [ - sample_indices( - u.shape[1], - stride=self.motion_stride, - expand_ratio=self.expand_ratio, - c=self.sample_c) for u in x ] - x = [ - torch.flip(torch.flip(u, [1])[:, idx], [1]) - for idx, u in zip(sample_idx, x) + ) + self.injector_pre_norm_feat = nn.ModuleList( + [ + nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) + for _ in range(audio_injector_id) ] - - # embeddings - x = [self.patch_embedding(u.unsqueeze(0)) for u in x] - x = self.after_patch_embedding(x) - - seq_f, seq_h, seq_w = x[0].shape[-3:] - batch_size = len(x) - if not self.enable_tsm: - grid_sizes = torch.stack( - [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) - grid_sizes = [[ - torch.zeros_like(grid_sizes), grid_sizes, grid_sizes - ]] - seq_f = 0 - else: - grid_sizes = [] - for idx in sample_idx[0][::-1][::self.sample_c]: - tsm_frame_grid_sizes = [[ - torch.tensor([idx, 0, - 0]).unsqueeze(0).repeat(batch_size, 1), - torch.tensor([idx + 1, seq_h, - seq_w]).unsqueeze(0).repeat(batch_size, 1), - torch.tensor([1, seq_h, - seq_w]).unsqueeze(0).repeat(batch_size, 1), - ]] - grid_sizes += tsm_frame_grid_sizes - seq_f = sample_idx[0][-1] + 1 - - x = [u.flatten(2).transpose(1, 2) for u in x] - seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) - x = torch.cat([u for u in x]) - - batch_size = len(x) - - token_grid_sizes = [[ - torch.tensor([seq_f, 0, 0]).unsqueeze(0).repeat(batch_size, 1), - torch.tensor( - [seq_f + 1, self.motion_side_len, - self.motion_side_len]).unsqueeze(0).repeat(batch_size, 1), - torch.tensor( - [1 if not self.trainable_token_pos_emb else -1, seq_h, - seq_w]).unsqueeze(0).repeat(batch_size, 1), - ] # 第三行代表rope emb的想要覆盖到的范围 - ] - - grid_sizes = grid_sizes + token_grid_sizes - token_unpatch_grid_sizes = torch.stack([ - torch.tensor([1, 32, 32], dtype=torch.long) - for b in range(batch_size) - ]) - token_len = self.token.shape[1] - token = self.token.clone().repeat(x.shape[0], 1, 1).contiguous() - seq_lens = seq_lens + torch.tensor([t.size(0) for t in token], - dtype=torch.long) - x = torch.cat([x, token], dim=1) - # arguments - kwargs = dict( - seq_lens=seq_lens, - grid_sizes=grid_sizes, - freqs=freqs, ) - - # grad ckpt args - def create_custom_forward(module, return_dict=None): - - def custom_forward(*inputs, **kwargs): - if return_dict is not None: - return module(*inputs, **kwargs, return_dict=return_dict) - else: - return module(*inputs, **kwargs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = ({ - "use_reentrant": False - } if is_torch_version(">=", "1.11.0") else {}) - - for idx, block in enumerate(self.blocks): - if self.training and self.gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, - **kwargs, - **ckpt_kwargs, + self.injector_pre_norm_vec = nn.ModuleList( + [ + nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) + for _ in range(audio_injector_id) + ] + ) + if enable_adain: + self.injector_adain_layers = nn.ModuleList( + [ + AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim, chunk_dim=1) + for _ in 
range(audio_injector_id) + ] + ) + if need_adain_ont: + self.injector_adain_output_layers = nn.ModuleList( + [nn.Linear(dim, dim) for _ in range(audio_injector_id)] ) - else: - x = block(x, **kwargs) - # head - out = x[:, -token_len:] - return out - - def unpatchify(self, x, grid_sizes): - c = self.out_dim - out = [] - for u, v in zip(x, grid_sizes.tolist()): - u = u[:math.prod(v)].view(*v, *self.patch_size, c) - u = torch.einsum('fhwpqrc->cfphqwr', u) - u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) - out.append(u) - return out - - def init_weights(self): - # basic init - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - - # init embeddings - nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) class FramePackMotioner(nn.Module): - def __init__( - self, - inner_dim=1024, - num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design - zip_frame_buckets=[ - 1, 2, 16 - ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames - drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion - *args, - **kwargs): + self, + inner_dim=1024, + num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design + zip_frame_buckets=[ + 1, + 2, + 16, + ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames + drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion + *args, + **kwargs, + ): super().__init__(*args, **kwargs) - self.proj = nn.Conv3d( - 16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) - self.proj_2x = nn.Conv3d( - 16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) - self.proj_4x = nn.Conv3d( - 16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) - self.zip_frame_buckets = torch.tensor( - zip_frame_buckets, dtype=torch.long) + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long) self.inner_dim = inner_dim self.num_heads = num_heads - assert (inner_dim % - num_heads) == 0 and (inner_dim // num_heads) % 2 == 0 + assert (inner_dim % num_heads) == 0 and (inner_dim // num_heads) % 2 == 0 d = inner_dim // num_heads - self.freqs = torch.cat([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)) - ], - dim=1) + self.freqs = torch.cat( + [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], + dim=1, + ) self.drop_mode = drop_mode def forward(self, motion_latents, add_last_motion=2): - motion_frames = motion_latents[0].shape[1] mot = [] mot_remb = [] for m in motion_latents: lat_height, lat_width = m.shape[2], m.shape[3] - padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, - lat_width).to( - device=m.device, dtype=m.dtype) + padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, lat_width).to( + device=m.device, dtype=m.dtype + ) overlap_frame = min(padd_lat.shape[1], m.shape[1]) if overlap_frame > 0: padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:] if add_last_motion < 2 and 
self.drop_mode != "drop": - zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets. - __len__() - - add_last_motion - - 1].sum() + zero_end_frame = self.zip_frame_buckets[: self.zip_frame_buckets.__len__() - add_last_motion - 1].sum() padd_lat[:, -zero_end_frame:] = 0 padd_lat = padd_lat.unsqueeze(0) - clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum( - ):, :, :].split( - list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2 ,1 + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[ + :, :, -self.zip_frame_buckets.sum() :, :, : + ].split(list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2 ,1 # patchfy - clean_latents_post = self.proj(clean_latents_post).flatten( - 2).transpose(1, 2) - clean_latents_2x = self.proj_2x(clean_latents_2x).flatten( - 2).transpose(1, 2) - clean_latents_4x = self.proj_4x(clean_latents_4x).flatten( - 2).transpose(1, 2) + clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2) + clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2) + clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2) if add_last_motion < 2 and self.drop_mode == "drop": - clean_latents_post = clean_latents_post[:, : - 0] if add_last_motion < 2 else clean_latents_post - clean_latents_2x = clean_latents_2x[:, : - 0] if add_last_motion < 1 else clean_latents_2x + clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post + clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x - motion_lat = torch.cat( - [clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) + motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) # rope start_time_id = -(self.zip_frame_buckets[:1].sum()) end_time_id = start_time_id + self.zip_frame_buckets[0] - grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \ - [ - [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), - torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), - torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] - ] + grid_sizes = ( + [] + if add_last_motion < 2 and self.drop_mode == "drop" + else [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]) + .unsqueeze(0) + .repeat(1, 1), + ] + ] + ) start_time_id = -(self.zip_frame_buckets[:2].sum()) end_time_id = start_time_id + self.zip_frame_buckets[1] // 2 - grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \ - [ - [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), - torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1), - torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] - ] + grid_sizes_2x = ( + [] + if add_last_motion < 1 and self.drop_mode == "drop" + else [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]) + .unsqueeze(0) + .repeat(1, 1), + ] + ] + ) start_time_id = -(self.zip_frame_buckets[:3].sum()) end_time_id = start_time_id + self.zip_frame_buckets[2] // 4 - 
grid_sizes_4x = [[ - torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), - torch.tensor([end_time_id, lat_height // 8, - lat_width // 8]).unsqueeze(0).repeat(1, 1), - torch.tensor([ - self.zip_frame_buckets[2], lat_height // 2, lat_width // 2 - ]).unsqueeze(0).repeat(1, 1), - ]] + grid_sizes_4x = [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]) + .unsqueeze(0) + .repeat(1, 1), + ] + ] grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x motion_rope_emb = rope_precompute( - motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, - self.inner_dim // self.num_heads), + motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, self.inner_dim // self.num_heads), grid_sizes, self.freqs, - start=None) + start=None, + ) mot.append(motion_lat) mot_remb.append(motion_rope_emb) return mot, mot_remb + class WanTimeTextAudioPoseEmbedding(nn.Module): def __init__( self, @@ -635,7 +509,9 @@ def __init__( self.act_fn = nn.SiLU() self.time_proj = nn.Linear(dim, time_proj_dim) self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") - self.casual_audio_encoder = CausalAudioEncoder(dim=audio_embed_dim, out_dim=dim, num_audio_token=4, need_global=enable_adain) + self.casual_audio_encoder = CausalAudioEncoder( + dim=audio_embed_dim, out_dim=dim, num_audio_token=4, need_global=enable_adain + ) self.pose_embedder = None if pose_embed_dim is not None: @@ -669,14 +545,12 @@ def forward( return temb, timestep_proj, encoder_hidden_states, audio_hidden_states, pose_hidden_states - class WanS2VTransformerBlock(nn.Module): def __init__( self, dim: int, ffn_dim: int, num_heads: int, - qk_norm: str = "rms_norm_across_heads", cross_attn_norm: bool = False, eps: float = 1e-6, added_kv_proj_dim: Optional[int] = None, @@ -768,7 +642,7 @@ def forward( class WanS2VTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" - A Transformer model for video-like data used in the Wan model. + A Transformer model for video-like data used in the Wan2.2-S2V model. 
Args: patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): @@ -806,7 +680,7 @@ class WanS2VTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["WanS2VTransformerBlock"] - _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3", "causal_audio_encoder"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @register_to_config @@ -820,70 +694,48 @@ def __init__( text_dim: int = 4096, freq_dim: int = 256, audio_dim: int = 1280, + audio_inject_layers: List[int] = [0, 4, 8, 12, 16, 20, 24, 27], enable_adain: bool = True, + adain_mode: str = "attn_norm", pose_dim: int = 1280, ffn_dim: int = 13824, num_layers: int = 40, cross_attn_norm: bool = True, - qk_norm: Optional[str] = "rms_norm_across_heads", eps: float = 1e-6, added_kv_proj_dim: Optional[int] = None, rope_max_seq_len: int = 1024, enable_motioner: bool = False, enable_framepack: bool = False, + framepack_drop_mode="padd", add_last_motion: bool = False, ) -> None: super().__init__() - inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels # 1. Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) - # init motioner + # Init motioner if enable_motioner and enable_framepack: raise ValueError( "enable_motioner and enable_framepack are mutually exclusive, please set one of them to False" ) - self.enable_motioner = enable_motioner self.add_last_motion = add_last_motion - if enable_motioner: - motioner_dim = 2048 - self.motioner = MotionerTransformers( - patch_size=(2, 4, 4), - dim=motioner_dim, - ffn_dim=motioner_dim, - freq_dim=256, - out_dim=16, - num_heads=16, - num_layers=13, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=False, - eps=1e-6, - motion_token_num=motion_token_num, - enable_tsm=enable_tsm, - motion_stride=4, - expand_ratio=2, - ) - self.zip_motion_out = torch.nn.Sequential( - WanLayerNorm(motioner_dim), - zero_module(nn.Linear(motioner_dim, self.dim))) - self.enable_framepack = enable_framepack if enable_framepack: self.frame_packer = FramePackMotioner( - inner_dim=self.dim, - num_heads=self.num_heads, + inner_dim=inner_dim, + num_heads=num_attention_heads, zip_frame_buckets=[1, 2, 16], - drop_mode=framepack_drop_mode) + drop_mode=framepack_drop_mode, + ) - self.trainable_cond_mask = nn.Embedding(3, self.dim) + self.trainable_condition_mask = nn.Embedding(3, inner_dim) - # 2. Condition embeddings - # image_embedding_dim=1280 for I2V model + # 2. Condition Embeddings self.condition_embedder = WanTimeTextAudioPoseEmbedding( dim=inner_dim, time_freq_dim=freq_dim, @@ -899,21 +751,22 @@ def __init__( self.blocks = nn.ModuleList( [ WanS2VTransformerBlock( - inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim + inner_dim, ffn_dim, num_attention_heads, cross_attn_norm, eps, added_kv_proj_dim ) for _ in range(num_layers) ] ) - + # 4. 
Audio Injector + all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks") self.audio_injector = AudioInjector( all_modules, all_modules_names, - dim=self.dim, - num_heads=self.num_heads, + dim=inner_dim, + num_heads=num_attention_heads, inject_layer=audio_inject_layers, root_net=self, enable_adain=enable_adain, - adain_dim=self.dim, + adain_dim=inner_dim, need_adain_ont=adain_mode != "attn_norm", ) @@ -924,43 +777,60 @@ def __init__( self.gradient_checkpointing = False - def inject_motion(self, - x, - seq_lens, - rope_embs, - mask_input, - motion_latents, - drop_motion_frames=False, - add_last_motion=True): - # inject the motion frames token to the hidden states - if self.enable_motioner: - mot, mot_remb = self.process_motion_transformer_motioner( - motion_latents, - drop_motion_frames=drop_motion_frames, - add_last_motion=add_last_motion) - elif self.enable_framepack: - mot, mot_remb = self.process_motion_frame_pack( - motion_latents, - drop_motion_frames=drop_motion_frames, - add_last_motion=add_last_motion) + def process_motion(self, motion_latents, drop_motion_frames=False): + if drop_motion_frames or motion_latents[0].shape[1] == 0: + return [], [] + self.latent_motion_frames = motion_latents[0].shape[1] + mot = [self.patch_embedding(m.unsqueeze(0)) for m in motion_latents] + batch_size = len(mot) + + mot_remb = [] + flattern_mot = [] + for bs in range(batch_size): + height, width = mot[bs].shape[3], mot[bs].shape[4] + flat_mot = mot[bs].flatten(2).transpose(1, 2).contiguous() + motion_grid_sizes = [ + [ + torch.tensor([-self.latent_motion_frames, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([0, height, width]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.latent_motion_frames, height, width]).unsqueeze(0).repeat(1, 1), + ] + ] + motion_rope_emb = rope_precompute( + flat_mot.detach().view( + 1, flat_mot.shape[1], self.num_attention_heads, self.inner_dim // self.num_attention_heads + ), + motion_grid_sizes, + self.freqs, + start=None, + ) + mot_remb.append(motion_rope_emb) + flattern_mot.append(flat_mot) + return flattern_mot, mot_remb + + def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2): + flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion) + if drop_motion_frames: + return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb] else: - mot, mot_remb = self.process_motion( - motion_latents, drop_motion_frames=drop_motion_frames) + return flattern_mot, mot_remb + + def inject_motion( + self, x, seq_lens, rope_embs, mask_input, motion_latents, drop_motion_frames=False, add_last_motion=True + ): + # Inject the motion frames token to the hidden states + if self.enable_framepack: + mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames, add_last_motion) + else: + mot, mot_remb = self.process_motion(motion_latents, drop_motion_frames) if len(mot) > 0: x = [torch.cat([u, m], dim=1) for u, m in zip(x, mot)] - seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], - dtype=torch.long) - rope_embs = [ - torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb) - ] + seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long) + rope_embs = [torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb)] mask_input = [ - torch.cat([ - m, 2 * torch.ones([1, u.shape[1] - m.shape[1]], - device=m.device, - dtype=m.dtype) - ], - dim=1) for m, u in zip(mask_input, x) + torch.cat([m, 2 * torch.ones([1, u.shape[1] - m.shape[1]], 
device=m.device, dtype=m.dtype)], dim=1) + for m, u in zip(mask_input, x) ] return x, seq_lens, rope_embs, mask_input From 323049dc10151212ac0b1dc29afd0ce77efb0738 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 2 Sep 2025 19:41:25 +0300 Subject: [PATCH 018/131] Removes unused code from the speech-to-video pipeline Removes the calculation of several unused variables and an unnecessary `deepcopy` operation on the latents tensor. This change also removes the now-unused `deepcopy` import, simplifying the overall logic. --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 7399538d7f07..a74fbe39acf3 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -14,7 +14,6 @@ import html import math -from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -899,10 +898,6 @@ def __call__( else: latents = latents_outputs - latent_target_frames = (num_frames_per_chunk + 3 + self.motion_frames) // 4 - latent_motion_frames - target_shape = [latent_target_frames, height // 8, width // 8] - max_seq_len = np.prod(target_shape) // 4 - latents = deepcopy(latents) with torch.no_grad(): left_idx = r * num_frames_per_chunk right_idx = r * num_frames_per_chunk + num_frames_per_chunk From 6515b2309c2cb966160d261482c18f3d41c73f9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 2 Sep 2025 19:43:41 +0300 Subject: [PATCH 019/131] Refactor WanS2VTransformer and improve conditioning Refactors the `WanS2VTransformer3DModel` for clarity and better handling of various conditioning inputs like audio, pose, and motion. Key changes: - Simplifies the `WanS2VTransformerBlock` by removing projection layers and streamlining the forward pass. - Introduces `after_transformer_block` to cleanly inject audio information after each transformer block, improving code organization. - Enhances the main `forward` method to better process and combine multiple conditioning signals (image, audio, motion) before the transformer blocks. - Adds support for a zero-value timestep to differentiate between image and video latents. - Generalizes temporal embedding logic to support multiple model variations. --- .../transformers/transformer_wan_s2v.py | 282 +++++++++++++----- 1 file changed, 200 insertions(+), 82 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index c3065e6522bd..058c8d56c04b 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -23,6 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps @@ -320,7 +321,7 @@ def __init__( self.injected_block_id[inject_id] = audio_injector_id audio_injector_id += 1 - # 2. 
Cross-attention + # Cross-attention self.injector = nn.ModuleList( [ WanAttention( @@ -336,24 +337,10 @@ def __init__( ] ) self.injector_pre_norm_feat = nn.ModuleList( - [ - nn.LayerNorm( - dim, - elementwise_affine=False, - eps=1e-6, - ) - for _ in range(audio_injector_id) - ] + [nn.LayerNorm(dim, elementwise_affine=False, eps=eps) for _ in range(audio_injector_id)] ) self.injector_pre_norm_vec = nn.ModuleList( - [ - nn.LayerNorm( - dim, - elementwise_affine=False, - eps=1e-6, - ) - for _ in range(audio_injector_id) - ] + [nn.LayerNorm(dim, elementwise_affine=False, eps=eps) for _ in range(audio_injector_id)] ) if enable_adain: self.injector_adain_layers = nn.ModuleList( @@ -418,9 +405,9 @@ def forward(self, motion_latents, add_last_motion=2): padd_lat = padd_lat.unsqueeze(0) clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[ :, :, -self.zip_frame_buckets.sum() :, :, : - ].split(list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2 ,1 + ].split(list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2, 1 - # patchfy + # patchify clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2) clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2) clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2) @@ -545,99 +532,92 @@ def forward( return temb, timestep_proj, encoder_hidden_states, audio_hidden_states, pose_hidden_states +@maybe_allow_in_graph class WanS2VTransformerBlock(nn.Module): def __init__( self, dim: int, ffn_dim: int, num_heads: int, + qk_norm: str = "rms_norm_across_heads", cross_attn_norm: bool = False, eps: float = 1e-6, added_kv_proj_dim: Optional[int] = None, - apply_input_projection: bool = False, - apply_output_projection: bool = False, ): super().__init__() - # 1. Input projection - self.proj_in = None - if apply_input_projection: - self.proj_in = nn.Linear(dim, dim) - - # 2. Self-attention + # 1. Self-attention self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) self.attn1 = WanAttention( dim=dim, heads=num_heads, dim_head=dim // num_heads, eps=eps, + cross_attention_dim_head=None, processor=WanAttnProcessor(), ) - # 3. Cross-attention + # 2. Cross-attention self.attn2 = WanAttention( dim=dim, heads=num_heads, dim_head=dim // num_heads, eps=eps, added_kv_proj_dim=added_kv_proj_dim, + cross_attention_dim_head=dim // num_heads, processor=WanAttnProcessor(), ) self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() - # 4. Feed-forward + # 3. Feed-forward self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) - # 5. 
Output projection - self.proj_out = None - if apply_output_projection: - self.proj_out = nn.Linear(dim, dim) - self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - control_hidden_states: torch.Tensor, temb: torch.Tensor, rotary_emb: torch.Tensor, ) -> torch.Tensor: - if self.proj_in is not None: - control_hidden_states = self.proj_in(control_hidden_states) - control_hidden_states = control_hidden_states + hidden_states - - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - self.scale_shift_table + temb.float() - ).chunk(6, dim=1) + if temb.ndim == 4: + # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + # batch_size, seq_len, 1, inner_dim + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) # 1. Self-attention - norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as( - control_hidden_states - ) + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) - control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) # 2. Cross-attention - norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states) + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) - control_hidden_states = control_hidden_states + attn_output + hidden_states = hidden_states + attn_output # 3. Feed-forward - norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( - control_hidden_states + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + hidden_states ) ff_output = self.ffn(norm_hidden_states) - control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as( - control_hidden_states - ) - - conditioning_states = None - if self.proj_out is not None: - conditioning_states = self.proj_out(control_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) - return conditioning_states, control_hidden_states + return hidden_states class WanS2VTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): @@ -675,6 +655,8 @@ class WanS2VTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr Whether to use img_emb. added_kv_proj_dim (`int`, *optional*, defaults to `None`): The number of channels to use for the added key and value projections. If `None`, no projection is used. 
+ zero_timestep (`bool`, defaults to `True`): + Whether to assign 0 value timestep to image/motion """ _supports_gradient_checkpointing = True @@ -704,13 +686,14 @@ def __init__( eps: float = 1e-6, added_kv_proj_dim: Optional[int] = None, rope_max_seq_len: int = 1024, - enable_motioner: bool = False, enable_framepack: bool = False, - framepack_drop_mode="padd", + framepack_drop_mode: str = "padd", add_last_motion: bool = False, + zero_timestep: bool = True, ) -> None: super().__init__() + self.add_last_motion = add_last_motion self.inner_dim = inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels @@ -718,13 +701,6 @@ def __init__( self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) - # Init motioner - if enable_motioner and enable_framepack: - raise ValueError( - "enable_motioner and enable_framepack are mutually exclusive, please set one of them to False" - ) - self.add_last_motion = add_last_motion - if enable_framepack: self.frame_packer = FramePackMotioner( inner_dim=inner_dim, @@ -756,6 +732,7 @@ def __init__( for _ in range(num_layers) ] ) + # 4. Audio Injector all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks") self.audio_injector = AudioInjector( @@ -764,10 +741,10 @@ def __init__( dim=inner_dim, num_heads=num_attention_heads, inject_layer=audio_inject_layers, - root_net=self, enable_adain=enable_adain, adain_dim=inner_dim, need_adain_ont=adain_mode != "attn_norm", + eps=eps, ) # 4. Output norm & projection @@ -816,7 +793,14 @@ def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, ad return flattern_mot, mot_remb def inject_motion( - self, x, seq_lens, rope_embs, mask_input, motion_latents, drop_motion_frames=False, add_last_motion=True + self, + hidden_states, + seq_lens, + rope_embs, + mask_input, + motion_latents, + drop_motion_frames=False, + add_last_motion=True, ): # Inject the motion frames token to the hidden states if self.enable_framepack: @@ -825,14 +809,49 @@ def inject_motion( mot, mot_remb = self.process_motion(motion_latents, drop_motion_frames) if len(mot) > 0: - x = [torch.cat([u, m], dim=1) for u, m in zip(x, mot)] + hidden_states = [torch.cat([u, m], dim=1) for u, m in zip(hidden_states, mot)] seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long) rope_embs = [torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb)] mask_input = [ torch.cat([m, 2 * torch.ones([1, u.shape[1] - m.shape[1]], device=m.device, dtype=m.dtype)], dim=1) - for m, u in zip(mask_input, x) + for m, u in zip(mask_input, hidden_states) ] - return x, seq_lens, rope_embs, mask_input + return hidden_states, seq_lens, rope_embs, mask_input + + def after_transformer_block( + self, + block_idx, + hidden_states, + original_sequence_length, + merged_audio_emb_num_frames, + attn_audio_emb, + audio_emb_global, + ): + if block_idx in self.audio_injector.injected_block_id.keys(): + audio_attn_id = self.audio_injector.injected_block_id[block_idx] + + input_hidden_states = hidden_states[:, :original_sequence_length].clone() # B (F H W) C + input_hidden_states = input_hidden_states.unflatten(1, (merged_audio_emb_num_frames, -1)).flatten(0, 1) + + if self.enbale_adain and self.adain_mode == "attn_norm": + attn_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id]( + input_hidden_states, temb=audio_emb_global[:, 0] + ) + else: + 
attn_hidden_states = self.audio_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states) + + residual_out = self.audio_injector.injector[audio_attn_id]( + x=attn_hidden_states, + context=attn_audio_emb, + context_lens=torch.ones( + attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device + ) + * attn_audio_emb.shape[1], + ) + residual_out = residual_out.unflatten(0, (-1, merged_audio_emb_num_frames)).flatten(1, 2) + hidden_states[:, :original_sequence_length] = hidden_states[:, :original_sequence_length] + residual_out + + return hidden_states def forward( self, @@ -840,14 +859,25 @@ def forward( timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, motion_latents: torch.Tensor, + audio_embeds: torch.Tensor, image_latents: torch.Tensor = None, pose_latents: torch.Tensor = None, - audio_embeds: torch.Tensor = None, - motion_frames: List[int] = None, + motion_frames: List[int] = [17, 5], drop_motion_frames: bool = False, + add_last_motion: int = 2, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + ... audio_embeds The input audio embedding [B, num_wav2vec_layer, C_a, T_a]. motion_frames The number of motion + frames and motion latents frames encoded by vae, i.e. [17, 5] add_last_motion For the motioner, if + add_last_motion > 0, it means that the most recent frame (i.e., the last frame) will be added. + For frame packing, the behavior depends on the value of add_last_motion: add_last_motion = + 0: Only the farthest part of the latent (i.e., clean_latents_4x) is included. + add_last_motion = 1: Both clean_latents_2x and clean_latents_4x are included. + add_last_motion = 2: All motion-related latents are used. + drop_motion_frames Bool, whether drop the motion frames info + """ if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -868,31 +898,119 @@ def forward( post_patch_num_frames = num_frames // p_t post_patch_height = height // p_h post_patch_width = width // p_w - - # 1. Rotary position embedding - rotary_emb = self.rope(hidden_states) + add_last_motion = self.add_last_motion * add_last_motion # 2. Patch embedding hidden_states = self.patch_embedding(hidden_states) - hidden_states = hidden_states.flatten(2).transpose(1, 2) - # 3. Time embedding + # 3. Condition embeddings + audio_embeds = torch.cat([audio_embeds[..., 0].repeat(1, 1, 1, motion_frames[0]), audio_embeds], dim=-1) + + if self.config.zero_timestep: + timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) + temb, timestep_proj, encoder_hidden_states, audio_hidden_states, pose_hidden_states = self.condition_embedder( timestep, encoder_hidden_states, audio_embeds, pose_latents ) timestep_proj = timestep_proj.unflatten(1, (6, -1)) - # 4. 
Image embedding + if self.enable_adain: + audio_emb_global, audio_emb = audio_hidden_states + audio_emb_global = audio_emb_global[:, motion_frames[1] :].clone() + else: + audio_emb = audio_hidden_states + merged_audio_emb = audio_emb[:, motion_frames[1] :, :] + + hidden_states = hidden_states + pose_hidden_states + grid_sizes = torch.tensor( + [post_patch_num_frames, post_patch_height, post_patch_width], dtype=torch.long + ).unsqueeze(0) + grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]] + hidden_states = hidden_states.flatten(2).transpose(1, 2) + sequence_length = hidden_states.shape[1].to(torch.long) + original_sequence_length = sequence_length + + if self.config.zero_timestep: + temb = temb[:-1] + zero_timestep_proj = timestep_proj[-1:] + timestep_proj = timestep_proj[:-1] + timestep_proj = torch.cat( + [timestep_proj.unsqueeze(2), zero_timestep_proj.unsqueeze(2).repeat(timestep_proj.shape[0], 1, 1, 1)], + dim=2, + ) + timestep_proj = [timestep_proj, original_sequence_length] + else: + timestep_proj = timestep_proj.unsqueeze(2).repeat(1, 1, 2, 1) + timestep_proj = [timestep_proj, 0] + + image_latents = self.patch_embedding(image_latents) + image_latents = image_latents.flatten(2).transpose(1, 2) + image_grid_sizes = [ + [ + # The start index + torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1), + # The end index + torch.tensor([31, height, width]).unsqueeze(0).repeat(batch_size, 1), + # The range + torch.tensor([1, height, width]).unsqueeze(0).repeat(batch_size, 1), + ] + ] + + sequence_length = sequence_length + image_latents.shape[1].to(torch.long) + grid_sizes = grid_sizes + image_grid_sizes + + hidden_states = torch.cat([hidden_states, image_latents], dim=1) + + # Initialize masks to indicate noisy latent, image latent, and motion latent. + # However, at this point, only the first two (noisy and image latents) are marked; + # the marking of motion latent will be implemented inside `inject_motion`. + mask_input = torch.zeros([1, hidden_states.shape[1]], dtype=torch.long, device=hidden_states.device) + mask_input[:, original_sequence_length:] = 1 + + # Rotary position embedding + rotary_emb = self.rope(hidden_states) + + hidden_states, sequence_length, pre_compute_freqs, mask_input = self.inject_motion( + hidden_states, + sequence_length, + # pre_compute_freqs, + mask_input, + motion_latents, + drop_motion_frames, + add_last_motion, + ) + + hidden_states = hidden_states + self.trainable_condition_mask(mask_input).to(hidden_states.dtype) + + merged_audio_emb_num_frames = merged_audio_emb.shape[1] # B F N C + attn_audio_emb = merged_audio_emb.flatten(0, 1) + audio_emb_global = audio_emb_global.flatten(0, 1) # 5. Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: - for i, block in enumerate(self.blocks): + for block_idx, block in enumerate(self.blocks): hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb ) + hidden_states = self.after_transformer_block( + block_idx, + hidden_states, + original_sequence_length, + merged_audio_emb_num_frames, + attn_audio_emb, + audio_emb_global, + ) else: - for i, block in enumerate(self.blocks): + for block_idx, block in enumerate(self.blocks): hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + hidden_states = self.after_transformer_block( + block_idx, + hidden_states, + original_sequence_length, + merged_audio_emb_num_frames, + attn_audio_emb, + audio_emb_global, + ) # 6. 
Output norm, projection & unpatchify shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) From fe5a626a0342e5dd50a8620756002da1f2877ca1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 2 Sep 2025 19:48:57 +0300 Subject: [PATCH 020/131] Add `AttentionMixin` to `WanS2VTransformer3DModel` --- src/diffusers/models/transformers/transformer_wan_s2v.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 058c8d56c04b..874829d76cfa 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -24,7 +24,7 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward +from ..attention import AttentionMixin, FeedForward from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput @@ -620,7 +620,9 @@ def forward( return hidden_states -class WanS2VTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): +class WanS2VTransformer3DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): r""" A Transformer model for video-like data used in the Wan2.2-S2V model. From bb5f10ab8d106cf2ffd57a82e989c5aa84250c41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 2 Sep 2025 23:12:59 +0300 Subject: [PATCH 021/131] fix: Update parameter name for audio encoder to `num_attention_heads` --- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 874829d76cfa..1766ed40dcf5 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -279,7 +279,7 @@ class CausalAudioEncoder(nn.Module): def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_audio_token=4, need_global=False): super().__init__() self.encoder = MotionEncoder_tc( - in_dim=dim, hidden_dim=out_dim, num_heads=num_audio_token, need_global=need_global + in_dim=dim, hidden_dim=out_dim, num_attention_heads=num_audio_token, need_global=need_global ) weight = torch.ones((1, num_layers, 1, 1)) * 0.01 From bb5f4c946f1a8b0b34e1a70d90e7e8bfeb77ab51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 2 Sep 2025 23:13:44 +0300 Subject: [PATCH 022/131] feat: Improve support for S2V model conversion Introduces the necessary configurations and state dictionary key mappings to enable the conversion of S2V model checkpoints to the Diffusers format. This includes: - A new transformer configuration for the S2V model architecture, including parameters for audio and pose conditioning. - A comprehensive rename dictionary to map the original S2V layer names to their Diffusers equivalents. 
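[Editor's note] The rename dictionary this commit introduces is applied by substring replacement over every checkpoint key. Below is a minimal, hedged sketch of that pattern; the two-entry `RENAME_DICT` and the `rename_state_dict_keys` helper are illustrative stand-ins for the full `S2V_TRANSFORMER_KEYS_RENAME_DICT` and the conversion script's own update logic, not the script itself.

```python
# Hypothetical sketch: substring-based key renaming when converting a checkpoint
# to the Diffusers layout. RENAME_DICT is a two-entry stand-in for the full
# S2V mapping added by this patch.
from typing import Dict

import torch

RENAME_DICT = {
    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
    "self_attn.q": "attn1.to_q",
}


def rename_state_dict_keys(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in RENAME_DICT.items():
            # Substring replacement: every occurrence of the old fragment is rewritten.
            new_key = new_key.replace(old, new)
        renamed[new_key] = value
    return renamed


if __name__ == "__main__":
    dummy = {"blocks.0.self_attn.q.weight": torch.zeros(4, 4)}
    print(list(rename_state_dict_keys(dummy).keys()))
    # ['blocks.0.attn1.to_q.weight']
```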
--- scripts/convert_wan_to_diffusers.py | 66 ++++++++++++++++++++++++++--- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 613189e30ea7..5682f62d1dcb 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -114,8 +114,51 @@ "after_proj": "proj_out", } +S2V_TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", + # Add attention component mappings + "self_attn.q": "attn1.to_q", + "self_attn.k": "attn1.to_k", + "self_attn.v": "attn1.to_v", + "self_attn.o": "attn1.to_out.0", + "self_attn.norm_q": "attn1.norm_q", + "self_attn.norm_k": "attn1.norm_k", + "cross_attn.q": "attn2.to_q", + "cross_attn.k": "attn2.to_k", + "cross_attn.v": "attn2.to_v", + "cross_attn.o": "attn2.to_out.0", + "cross_attn.norm_q": "attn2.norm_q", + "cross_attn.norm_k": "attn2.norm_k", + "attn2.to_k_img": "attn2.add_k_proj", + "attn2.to_v_img": "attn2.add_v_proj", + "attn2.norm_k_img": "attn2.norm_added_k", + # S2V-specific audio component mappings + "casual_audio_encoder.encoder": "condition_embedder.casual_audio_encoder.encoder", + "casual_audio_encoder.weights": "condition_embedder.casual_audio_encoder.weights", + # Pose condition encoder mappings + "cond_encoder.weight": "condition_embedder.pose_embedder.weight", + "cond_encoder.bias": "condition_embedder.pose_embedder.bias", + "trainable_cond_mask": "trainable_condition_mask", +} + TRANSFORMER_SPECIAL_KEYS_REMAP = {} VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {} +S2V_TRANSFORMER_SPECIAL_KEYS_REMAP = {} def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: @@ -358,19 +401,28 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: "attention_head_dim": 128, "cross_attn_norm": True, "eps": 1e-06, - "ffn_dim": 14336, + "ffn_dim": 13824, "freq_dim": 256, - "in_channels": 48, - "num_attention_heads": 24, - "num_layers": 30, - "out_channels": 48, + "in_channels": 16, + "num_attention_heads": 40, + "num_layers": 40, + "out_channels": 16, "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096, + "audio_dim": 1024, + "audio_inject_layers": [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39], + "enable_adain": True, + "adain_mode": "attn_norm", + "pose_dim": 1280, + "enable_framepack": True, + "framepack_drop_mode": "padd", + "add_last_motion": True, + "zero_timestep": True, }, } - RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT - SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP + RENAME_DICT = S2V_TRANSFORMER_KEYS_RENAME_DICT + SPECIAL_KEYS_REMAP = S2V_TRANSFORMER_SPECIAL_KEYS_REMAP return config, RENAME_DICT, SPECIAL_KEYS_REMAP From f6fb523ac413cfebfea0f4e08f5b7e912c5118a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 09:25:00 +0300 Subject: [PATCH 
023/131] simplify --- .../transformers/transformer_wan_s2v.py | 67 ++++++++++--------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 1766ed40dcf5..6d9d45024574 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -187,8 +187,7 @@ def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_m super().__init__() self.pad_mode = pad_mode - padding = (kernel_size - 1, 0) # T - self.time_causal_padding = padding + self.time_causal_padding = (kernel_size - 1, 0) # T self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) @@ -198,10 +197,7 @@ def forward(self, x): class MotionEncoder_tc(nn.Module): - def __init__( - self, in_dim: int, hidden_dim: int, num_attention_heads=int, need_global=True, dtype=None, device=None - ): - factory_kwargs = {"dtype": dtype, "device": device} + def __init__(self, in_dim: int, hidden_dim: int, num_attention_heads=int, need_global=True): super().__init__() self.num_attention_heads = num_attention_heads @@ -209,19 +205,16 @@ def __init__( self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_attention_heads, 3, stride=1) if need_global: self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1) - self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs) self.act = nn.SiLU() self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2) self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2) if need_global: - self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs) - - self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs) - - self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.final_linear = nn.Linear(hidden_dim, hidden_dim) - self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6) + self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6) self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) @@ -336,12 +329,14 @@ def __init__( for _ in range(audio_injector_id) ] ) + self.injector_pre_norm_feat = nn.ModuleList( [nn.LayerNorm(dim, elementwise_affine=False, eps=eps) for _ in range(audio_injector_id)] ) self.injector_pre_norm_vec = nn.ModuleList( [nn.LayerNorm(dim, elementwise_affine=False, eps=eps) for _ in range(audio_injector_id)] ) + if enable_adain: self.injector_adain_layers = nn.ModuleList( [ @@ -359,7 +354,7 @@ class FramePackMotioner(nn.Module): def __init__( self, inner_dim=1024, - num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design + num_attention_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design zip_frame_buckets=[ 1, 2, @@ -370,21 +365,28 @@ def __init__( **kwargs, ): super().__init__(*args, **kwargs) + self.inner_dim = inner_dim + self.num_attention_heads = num_attention_heads + if (inner_dim % num_attention_heads) != 0 or (inner_dim // num_attention_heads) % 2 != 0: + raise ValueError( + "inner_dim must be divisible by num_attention_heads and 
inner_dim // num_attention_heads must be even" + ) + self.drop_mode = drop_mode + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long) - self.inner_dim = inner_dim - self.num_heads = num_heads - - assert (inner_dim % num_heads) == 0 and (inner_dim // num_heads) % 2 == 0 - d = inner_dim // num_heads + head_dim = inner_dim // num_attention_heads self.freqs = torch.cat( - [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], + [ + rope_params(1024, head_dim - 4 * (head_dim // 6)), + rope_params(1024, 2 * (head_dim // 6)), + rope_params(1024, 2 * (head_dim // 6)), + ], dim=1, ) - self.drop_mode = drop_mode def forward(self, motion_latents, add_last_motion=2): mot = [] @@ -870,15 +872,20 @@ def forward( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - """ - ... audio_embeds The input audio embedding [B, num_wav2vec_layer, C_a, T_a]. motion_frames The number of motion - frames and motion latents frames encoded by vae, i.e. [17, 5] add_last_motion For the motioner, if - add_last_motion > 0, it means that the most recent frame (i.e., the last frame) will be added. - For frame packing, the behavior depends on the value of add_last_motion: add_last_motion = - 0: Only the farthest part of the latent (i.e., clean_latents_4x) is included. - add_last_motion = 1: Both clean_latents_2x and clean_latents_4x are included. - add_last_motion = 2: All motion-related latents are used. - drop_motion_frames Bool, whether drop the motion frames info + r""" + Parameters: + audio_embeds: + The input audio embedding [B, num_wav2vec_layer, C_a, T_a]. + motion_frames: + The number of motion frames and motion latents frames encoded by vae, i.e. [17, 5]. + add_last_motion: + For the motioner, if add_last_motion > 0, it means that the most recent frame (i.e., the last frame) + will be added. For frame packing, the behavior depends on the value of add_last_motion: add_last_motion + = 0: Only the farthest part of the latent (i.e., clean_latents_4x) is included. add_last_motion = 1: + Both clean_latents_2x and clean_latents_4x are included. add_last_motion = 2: All motion-related + latents are used. + drop_motion_frames: + Bool, whether drop the motion frames info. """ if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() From dfec1529e1c44075eaaccb97e501836989a78404 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 09:55:05 +0300 Subject: [PATCH 024/131] up --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index a74fbe39acf3..5eb0daecc84a 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -87,7 +87,7 @@ ... negative_prompt=negative_prompt, ... height=height, ... width=width, - ... num_frames=81, + ... num_frames_per_chunk=81, ... guidance_scale=5.0, ... 
).frames[0] >>> export_to_video(output, "output.mp4", fps=16) @@ -304,7 +304,7 @@ def encode_audio( fixed_start=0, ) batch_audio_eb = [] - audio_sample_stride = int(self.video_rate / fps) + audio_sample_stride = int(video_rate / fps) for bi in batch_idx: if bi < audio_frame_num: chosen_idx = list( @@ -330,7 +330,7 @@ def encode_audio( batch_audio_eb.append(frame_audio_embed) audio_embed_bucket = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) - audio_embed_bucket = audio_embed_bucket.to(self.device, self.param_dtype) + audio_embed_bucket = audio_embed_bucket.to(device, self.config.dtype) audio_embed_bucket = audio_embed_bucket.unsqueeze(0) if len(audio_embed_bucket.shape) == 3: audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) @@ -939,6 +939,7 @@ def __call__( encoder_hidden_states=negative_prompt_embeds, motion_latents=motion_latents_input, image_latents=condition, + pose_latents=pose_latents, audio_embeds=0.0 * audio_embeds_input, motion_frames=[self.motion_frames, latent_motion_frames], drop_motion_frames=self.drop_first_motion and r == 0, From 21cd65fc786559533c1ed2eef4a674f25611fe62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 10:24:06 +0300 Subject: [PATCH 025/131] refactor: Simplify AdaLayerNorm initialization and forward method --- .../transformers/transformer_wan_s2v.py | 33 +++---------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 6d9d45024574..c2f41c2fd32b 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -131,52 +131,32 @@ class AdaLayerNorm(nn.Module): Parameters: embedding_dim (`int`): The size of each embedding vector. - num_embeddings (`int`, *optional*): The size of the embeddings dictionary. output_dim (`int`, *optional*): norm_elementwise_affine (`bool`, defaults to `False): norm_eps (`bool`, defaults to `False`): - chunk_dim (`int`, defaults to `0`): """ def __init__( self, embedding_dim: int, - num_embeddings: Optional[int] = None, output_dim: Optional[int] = None, norm_elementwise_affine: bool = False, norm_eps: float = 1e-5, - chunk_dim: int = 0, ): super().__init__() - self.chunk_dim = chunk_dim output_dim = output_dim or embedding_dim * 2 - if num_embeddings is not None: - self.emb = nn.Embedding(num_embeddings, embedding_dim) - else: - self.emb = None - self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, output_dim) self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) - def forward( - self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None - ) -> torch.Tensor: - if self.emb is not None: - temb = self.emb(timestep) - + def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: temb = self.linear(self.silu(temb)) - if self.chunk_dim == 1: - # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the - # other if-branch. This branch is specific to CogVideoX and OmniGen for now. 
- shift, scale = temb.chunk(2, dim=1) - shift = shift[:, None, :] - scale = scale[:, None, :] - else: - scale, shift = temb.chunk(2, dim=0) + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] x = self.norm(x) * (1 + scale) + shift return x @@ -339,10 +319,7 @@ def __init__( if enable_adain: self.injector_adain_layers = nn.ModuleList( - [ - AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim, chunk_dim=1) - for _ in range(audio_injector_id) - ] + [AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(audio_injector_id)] ) if need_adain_ont: self.injector_adain_output_layers = nn.ModuleList( From 89b9bcb5d52859f91c2ef350dee15e5e18e05b31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 11:56:05 +0300 Subject: [PATCH 026/131] fix: Correct parameter value for pose_dim and name for num_attention_heads in transformer configuration --- scripts/convert_wan_to_diffusers.py | 2 +- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 5682f62d1dcb..eb18f6d5daa1 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -414,7 +414,7 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: "audio_inject_layers": [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39], "enable_adain": True, "adain_mode": "attn_norm", - "pose_dim": 1280, + "pose_dim": 16, "enable_framepack": True, "framepack_drop_mode": "padd", "add_last_motion": True, diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index c2f41c2fd32b..294728e19d21 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -685,7 +685,7 @@ def __init__( if enable_framepack: self.frame_packer = FramePackMotioner( inner_dim=inner_dim, - num_heads=num_attention_heads, + num_attention_heads=num_attention_heads, zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode, ) From 167bd23c49ff387c325ea38cdb411c6949021218 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 13:10:43 +0300 Subject: [PATCH 027/131] fix: Update audio injector to use WanTransformerBlock instead of WanAttention --- src/diffusers/models/transformers/transformer_wan_s2v.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 294728e19d21..8eb573f4304d 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -34,6 +34,7 @@ WanAttention, WanAttnProcessor, WanRotaryPosEmbed, + WanTransformerBlock, ) @@ -288,7 +289,7 @@ def __init__( self.injected_block_id = {} audio_injector_id = 0 for mod_name, mod in zip(all_modules_names, all_modules): - if isinstance(mod, WanAttention): + if isinstance(mod, WanTransformerBlock): for inject_id in inject_layer: if f"transformer_blocks.{inject_id}" in mod_name: self.injected_block_id[inject_id] = audio_injector_id From 9b6bf4b17479fe37eb5bc438054368c92f861306 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 13:53:02 +0300 Subject: [PATCH 028/131] upp --- scripts/convert_wan_to_diffusers.py | 2 +- 
src/diffusers/models/transformers/transformer_wan_s2v.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index eb18f6d5daa1..8c04a74fe47a 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -1092,4 +1092,4 @@ def get_args(): scheduler=scheduler, ) - pipe.save_pretrained(args.output_path, push_to_hub=True, safe_serialization=True, max_shard_size="5GB") + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 8eb573f4304d..7be76cd18514 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -34,7 +34,6 @@ WanAttention, WanAttnProcessor, WanRotaryPosEmbed, - WanTransformerBlock, ) @@ -289,7 +288,7 @@ def __init__( self.injected_block_id = {} audio_injector_id = 0 for mod_name, mod in zip(all_modules_names, all_modules): - if isinstance(mod, WanTransformerBlock): + if isinstance(mod, WanS2VTransformerBlock): for inject_id in inject_layer: if f"transformer_blocks.{inject_id}" in mod_name: self.injected_block_id[inject_id] = audio_injector_id @@ -665,6 +664,7 @@ def __init__( ffn_dim: int = 13824, num_layers: int = 40, cross_attn_norm: bool = True, + qk_norm: Optional[str] = "rms_norm_across_heads", eps: float = 1e-6, added_kv_proj_dim: Optional[int] = None, rope_max_seq_len: int = 1024, @@ -709,7 +709,7 @@ def __init__( self.blocks = nn.ModuleList( [ WanS2VTransformerBlock( - inner_dim, ffn_dim, num_attention_heads, cross_attn_norm, eps, added_kv_proj_dim + inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim ) for _ in range(num_layers) ] From 4bed628e493d903e7ee6ccf5bab344ea068b194b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 13:57:24 +0300 Subject: [PATCH 029/131] feat: Add audio injector attention mappings to transformer key renaming --- scripts/convert_wan_to_diffusers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 8c04a74fe47a..f642d7051c54 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -154,6 +154,12 @@ "cond_encoder.weight": "condition_embedder.pose_embedder.weight", "cond_encoder.bias": "condition_embedder.pose_embedder.bias", "trainable_cond_mask": "trainable_condition_mask", + # Audio injector attention mappings - convert original q/k/v/o format to diffusers format + **{ + f"'audio_injector.injector.'{i}.{src}": f"'audio_injector.injector.'{i}.{dst}" + for i in range(12) + for src, dst in [("q", "to_q"), ("k", "to_k"), ("v", "to_v"), ("o", "to_out.0")] + }, } TRANSFORMER_SPECIAL_KEYS_REMAP = {} From a112328de03fcd1bad1c5c71a514c2e8cd357179 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 15:19:48 +0300 Subject: [PATCH 030/131] up docs --- docs/source/en/api/pipelines/wan.md | 6 +++--- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 15 ++++++++++++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 3754ecdbcb7c..48caf926263c 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -253,7 +253,7 @@ The example below 
demonstrates how to use the speech-to-video pipeline to genera import numpy as np import torch from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline -from diffusers.utils import export_to_video, load_image, load_audio +from diffusers.utils import export_to_video, load_image, load_audio, load_video from transformers import Wav2Vec2ForCTC @@ -283,8 +283,8 @@ prompt = "CG animation style, a small blue bird takes off from the ground, flapp output = pipe( image=first_frame, audio=audio, sampling_rate=sampling_rate, - prompt=prompt, height=height, width=width, guidance_scale=5.0, - # pose_video=pose_video + prompt=prompt, height=height, width=width, guidance_scale=5.0, num_frames_per_chunk=81, + #pose_video=pose_video ).frames[0] export_to_video(output, "output.mp4", fps=16) ``` diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 5eb0daecc84a..0798f01d1523 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -56,7 +56,7 @@ >>> import torch >>> import numpy as np >>> from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline - >>> from diffusers.utils import export_to_video, load_image, load_audio + >>> from diffusers.utils import export_to_video, load_image, load_audio, load_video >>> # Available models: Wan-AI/Wan2.2-S2V-14B-Diffusers >>> model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" @@ -64,9 +64,16 @@ >>> pipe = WanSpeechToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) >>> pipe.to("cuda") - >>> image = load_image( - ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + >>> first_frame = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png" ... ) + >>> audio, sampling_rate = load_audio( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png" + ... ) + >>> pose_video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_pose_video.mp4" + ... ) + >>> max_area = 480 * 832 >>> aspect_ratio = image.height / image.width >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] @@ -84,6 +91,8 @@ ... prompt=prompt, ... image=image, ... audio=audio, + ... sampling_rate=sampling_rate, + ... # pose_video=pose_video, ... negative_prompt=negative_prompt, ... height=height, ... width=width, From c798d93aaaf359ea859a3b4fa2d5158ac22a7d3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 18:06:26 +0300 Subject: [PATCH 031/131] Adapt the `WanS2VTransformerBlock` to handle the new `temb` format, which includes segmentation information. 
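As a minimal illustrative sketch (not part of the diff below; the tensor sizes and the split
point are assumed for demonstration), the new `temb` is a pair `[params, seg_idx]` and the
modulation is applied separately to the video tokens and to the reference-image tokens
appended after `seg_idx`:

    import torch

    B, S, D = 2, 12, 16                      # assumed sizes, for illustration only
    seg = 9                                  # assumed boundary: tokens [0, 9) video, [9, 12) image
    hidden = torch.randn(B, S, D)
    scale_shift_table = torch.randn(1, 6, D)
    temb = torch.randn(B, 6, 2, D)           # one parameter set per segment

    # (1, 6, 1, D) + (B, 6, 2, D) -> six chunks of (B, 1, 2, D) -> squeeze to (B, 2, D)
    shift, scale, gate, c_shift, c_scale, c_gate = (
        scale_shift_table.unsqueeze(2) + temb
    ).chunk(6, dim=1)
    shift, scale = shift.squeeze(1), scale.squeeze(1)

    seg_idx = [0, seg, S]
    parts = []
    for i in range(2):                       # segment 0: video tokens, segment 1: image tokens
        parts.append(
            hidden[:, seg_idx[i]:seg_idx[i + 1]] * (1 + scale[:, i:i + 1]) + shift[:, i:i + 1]
        )
    modulated = torch.cat(parts, dim=1)      # same shape as hidden, per-segment scale/shift applied

The gate and FFN parameters follow the same per-segment pattern; see the block rewrite below.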
--- .../transformers/transformer_wan_s2v.py | 65 ++++++++++++------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 7be76cd18514..5546be4b27c9 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -558,31 +558,39 @@ def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - temb: torch.Tensor, + temb: List[torch.Tensor, torch.Tensor], rotary_emb: torch.Tensor, ) -> torch.Tensor: - if temb.ndim == 4: - # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - self.scale_shift_table.unsqueeze(0) + temb.float() - ).chunk(6, dim=2) - # batch_size, seq_len, 1, inner_dim - shift_msa = shift_msa.squeeze(2) - scale_msa = scale_msa.squeeze(2) - gate_msa = gate_msa.squeeze(2) - c_shift_msa = c_shift_msa.squeeze(2) - c_scale_msa = c_scale_msa.squeeze(2) - c_gate_msa = c_gate_msa.squeeze(2) - else: - # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - self.scale_shift_table + temb.float() - ).chunk(6, dim=1) + seg_idx = temb[1].item() + seg_idx = min(max(0, seg_idx), hidden_states.shape[1]) + seg_idx = [0, seg_idx, hidden_states.shape[1]] + temb = temb[0] + # temb: batch_size, 6, 2, inner_dim + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(2) + temb.float() + ).chunk(6, dim=1) + # batch_size, 1, seq_len, inner_dim + shift_msa = shift_msa.squeeze(1) + scale_msa = scale_msa.squeeze(1) + gate_msa = gate_msa.squeeze(1) + c_shift_msa = c_shift_msa.squeeze(1) + c_scale_msa = c_scale_msa.squeeze(1) + c_gate_msa = c_gate_msa.squeeze(1) + + norm_hidden_states = self.norm1(hidden_states.float()) + parts = [] + for i in range(2): + parts.append(norm_hidden_states[:, seg_idx[i]:seg_idx[i + 1]] * + (1 + scale_msa[:, i:i + 1]) + shift_msa[:, i:i + 1]) + norm_hidden_states = torch.cat(parts, dim=1).type_as(hidden_states) # 1. Self-attention - norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) - hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + z = [] + for i in range(2): + z.append(attn_output[:, seg_idx[i]:seg_idx[i + 1]] * gate_msa[:, i:i + 1]) + attn_output = torch.cat(z, dim=1) + hidden_states = (hidden_states.float() + attn_output).type_as(hidden_states) # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) @@ -590,11 +598,18 @@ def forward( hidden_states = hidden_states + attn_output # 3. 
Feed-forward - norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( - hidden_states - ) - ff_output = self.ffn(norm_hidden_states) - hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + norm3_hidden_states = self.norm3(hidden_states.float()) + parts = [] + for i in range(2): + parts.append(norm3_hidden_states[:, seg_idx[i]:seg_idx[i + 1]] * + (1 + c_scale_msa[:, i:i + 1]) + c_shift_msa[:, i:i + 1]) + norm3_hidden_states = torch.cat(parts, dim=1).type_as(hidden_states) + ff_output = self.ffn(norm3_hidden_states) + z = [] + for i in range(2): + z.append(ff_output[:, seg_idx[i]:seg_idx[i + 1]] * c_gate_msa[:, i:i + 1]) + ff_output = torch.cat(z, dim=1) + hidden_states = (hidden_states.float() + ff_output.float()).type_as(hidden_states) return hidden_states From d612c412c0491e382ab0c0f7e5b41cc77e967f27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 18:06:46 +0300 Subject: [PATCH 032/131] style --- .../models/transformers/transformer_wan_s2v.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 5546be4b27c9..751569b161be 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -580,15 +580,17 @@ def forward( norm_hidden_states = self.norm1(hidden_states.float()) parts = [] for i in range(2): - parts.append(norm_hidden_states[:, seg_idx[i]:seg_idx[i + 1]] * - (1 + scale_msa[:, i:i + 1]) + shift_msa[:, i:i + 1]) + parts.append( + norm_hidden_states[:, seg_idx[i] : seg_idx[i + 1]] * (1 + scale_msa[:, i : i + 1]) + + shift_msa[:, i : i + 1] + ) norm_hidden_states = torch.cat(parts, dim=1).type_as(hidden_states) # 1. 
Self-attention attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) z = [] for i in range(2): - z.append(attn_output[:, seg_idx[i]:seg_idx[i + 1]] * gate_msa[:, i:i + 1]) + z.append(attn_output[:, seg_idx[i] : seg_idx[i + 1]] * gate_msa[:, i : i + 1]) attn_output = torch.cat(z, dim=1) hidden_states = (hidden_states.float() + attn_output).type_as(hidden_states) @@ -601,13 +603,15 @@ def forward( norm3_hidden_states = self.norm3(hidden_states.float()) parts = [] for i in range(2): - parts.append(norm3_hidden_states[:, seg_idx[i]:seg_idx[i + 1]] * - (1 + c_scale_msa[:, i:i + 1]) + c_shift_msa[:, i:i + 1]) + parts.append( + norm3_hidden_states[:, seg_idx[i] : seg_idx[i + 1]] * (1 + c_scale_msa[:, i : i + 1]) + + c_shift_msa[:, i : i + 1] + ) norm3_hidden_states = torch.cat(parts, dim=1).type_as(hidden_states) ff_output = self.ffn(norm3_hidden_states) z = [] for i in range(2): - z.append(ff_output[:, seg_idx[i]:seg_idx[i + 1]] * c_gate_msa[:, i:i + 1]) + z.append(ff_output[:, seg_idx[i] : seg_idx[i + 1]] * c_gate_msa[:, i : i + 1]) ff_output = torch.cat(z, dim=1) hidden_states = (hidden_states.float() + ff_output.float()).type_as(hidden_states) From f1ef8facc4cdc550965f0742da211a0666888755 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 20:30:06 +0300 Subject: [PATCH 033/131] Simplify --- .../models/transformers/transformer_wan_s2v.py | 4 ++-- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 16 +++++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 751569b161be..9bf4ef9bbc12 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -28,7 +28,7 @@ from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin +from ..modeling_utils import ModelMixin, get_parameter_dtype from ..normalization import FP32LayerNorm from .transformer_wan import ( WanAttention, @@ -495,7 +495,7 @@ def forward( if timestep_seq_len is not None: timestep = timestep.unflatten(0, (-1, timestep_seq_len)) - time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + time_embedder_dtype = get_parameter_dtype(self.time_embedder) if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: timestep = timestep.to(time_embedder_dtype) temb = self.time_embedder(timestep).type_as(encoder_hidden_states) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 0798f01d1523..c4d59ba4c608 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -146,7 +146,11 @@ def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed def linear_interpolation(features, input_fps, output_fps, output_len=None): """ - features: shape=[1, T, 512] input_fps: fps for audio, f_a output_fps: fps for video, f_m output_len: video length + Args: + features: shape=[1, T, 512] + input_fps: fps for audio, f_a + output_fps: fps for video, f_m + output_len: video length """ features = features.transpose(1, 2) # [1, 512, T] seq_len = features.shape[2] / float(input_fps) # T/f_a @@ -604,12 +608,10 @@ def load_pose_condition(self, pose_video, num_chunks, 
num_frames_per_chunk, size cond_tensors = [-torch.ones([1, 3, num_frames_per_chunk, HEIGHT, WIDTH])] pose_condition = [] - for r in range(len(cond_tensors)): - cond = cond_tensors[r] - cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2) - cond_lat = torch.stack(self.vae.encode(cond.to(dtype=self.param_dtype, device=self.device)))[ - :, :, 1: - ].cpu() # for mem save + for cond in cond_tensors: + cond = torch.cat([cond[:, :, 0:1], cond], dim=2) + cond = cond.to(dtype=self.config.dtype, device=self._execution_device) + cond_lat = self.vae.encode(cond)[:, :, 1:] pose_condition.append(cond_lat) return pose_condition From 30be7e88f7d597316e3c8cf512ecb3c7f80debbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 20:53:48 +0300 Subject: [PATCH 034/131] Remove unused audio encoder import --- src/diffusers/pipelines/wan/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/wan/__init__.py b/src/diffusers/pipelines/wan/__init__.py index adb716943359..f21a66dbb7e6 100644 --- a/src/diffusers/pipelines/wan/__init__.py +++ b/src/diffusers/pipelines/wan/__init__.py @@ -22,7 +22,6 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["audio_encoder"] = ["WanAudioEncoder"] _import_structure["pipeline_wan"] = ["WanPipeline"] _import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"] _import_structure["pipeline_wan_s2v"] = ["WanSpeechToVideoPipeline"] @@ -36,7 +35,6 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: - from .audio_encoder import WanAudioEncoder from .pipeline_wan import WanPipeline from .pipeline_wan_i2v import WanImageToVideoPipeline from .pipeline_wan_s2v import WanSpeechToVideoPipeline From 7ee98eb931e910c61dcc0d46aba61157e47b1663 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 21:40:57 +0300 Subject: [PATCH 035/131] Fix typo --- scripts/convert_wan_to_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index f642d7051c54..1c7bf56ed5c1 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -156,7 +156,7 @@ "trainable_cond_mask": "trainable_condition_mask", # Audio injector attention mappings - convert original q/k/v/o format to diffusers format **{ - f"'audio_injector.injector.'{i}.{src}": f"'audio_injector.injector.'{i}.{dst}" + f"audio_injector.injector.{i}.{src}": f"audio_injector.injector.{i}.{dst}" for i in range(12) for src, dst in [("q", "to_q"), ("k", "to_k"), ("v", "to_v"), ("o", "to_out.0")] }, From 4674ead12fd900431aacd39279c04cda8cd03272 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 3 Sep 2025 23:18:21 +0300 Subject: [PATCH 036/131] up --- .../transformers/transformer_wan_s2v.py | 40 +++++++++---------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 9bf4ef9bbc12..9f2ca02bbdcf 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -405,11 +405,9 @@ def forward(self, motion_latents, add_last_motion=2): if add_last_motion < 2 and self.drop_mode == "drop" else [ [ - torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), - torch.tensor([end_time_id, 
lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), - torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]) - .unsqueeze(0) - .repeat(1, 1), + torch.tensor([start_time_id, 0, 0]).unsqueeze(0), + torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0), + torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0), ] ] ) @@ -421,11 +419,9 @@ def forward(self, motion_latents, add_last_motion=2): if add_last_motion < 1 and self.drop_mode == "drop" else [ [ - torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), - torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1), - torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]) - .unsqueeze(0) - .repeat(1, 1), + torch.tensor([start_time_id, 0, 0]).unsqueeze(0), + torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0), + torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0), ] ] ) @@ -434,18 +430,18 @@ def forward(self, motion_latents, add_last_motion=2): end_time_id = start_time_id + self.zip_frame_buckets[2] // 4 grid_sizes_4x = [ [ - torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), - torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0).repeat(1, 1), - torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]) - .unsqueeze(0) - .repeat(1, 1), + torch.tensor([start_time_id, 0, 0]).unsqueeze(0), + torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0), + torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]).unsqueeze(0), ] ] grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x motion_rope_emb = rope_precompute( - motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, self.inner_dim // self.num_heads), + motion_lat.detach().view( + 1, motion_lat.shape[1], self.num_attention_heads, self.inner_dim // self.num_attention_heads + ), grid_sizes, self.freqs, start=None, @@ -475,7 +471,7 @@ def __init__( self.act_fn = nn.SiLU() self.time_proj = nn.Linear(dim, time_proj_dim) self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") - self.casual_audio_encoder = CausalAudioEncoder( + self.causal_audio_encoder = CausalAudioEncoder( dim=audio_embed_dim, out_dim=dim, num_audio_token=4, need_global=enable_adain ) @@ -503,7 +499,7 @@ def forward( encoder_hidden_states = self.text_embedder(encoder_hidden_states) - audio_hidden_states = self.casual_audio_encoder(audio_hidden_states) + audio_hidden_states = self.causal_audio_encoder(audio_hidden_states) if self.pose_embedder is not None: pose_hidden_states = self.pose_embedder(pose_hidden_states) @@ -769,9 +765,9 @@ def process_motion(self, motion_latents, drop_motion_frames=False): flat_mot = mot[bs].flatten(2).transpose(1, 2).contiguous() motion_grid_sizes = [ [ - torch.tensor([-self.latent_motion_frames, 0, 0]).unsqueeze(0).repeat(1, 1), - torch.tensor([0, height, width]).unsqueeze(0).repeat(1, 1), - torch.tensor([self.latent_motion_frames, height, width]).unsqueeze(0).repeat(1, 1), + torch.tensor([-self.latent_motion_frames, 0, 0]).unsqueeze(0), + torch.tensor([0, height, width]).unsqueeze(0), + torch.tensor([self.latent_motion_frames, height, width]).unsqueeze(0), ] ] motion_rope_emb = rope_precompute( @@ -834,7 +830,7 @@ def after_transformer_block( input_hidden_states = hidden_states[:, :original_sequence_length].clone() # B (F H W) C input_hidden_states = input_hidden_states.unflatten(1, 
(merged_audio_emb_num_frames, -1)).flatten(0, 1) - if self.enbale_adain and self.adain_mode == "attn_norm": + if self.enable_adain and self.adain_mode == "attn_norm": attn_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id]( input_hidden_states, temb=audio_emb_global[:, 0] ) From 6ee3b85825cce240cba6a922f456079570e03360 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 4 Sep 2025 08:48:30 +0300 Subject: [PATCH 037/131] simplify --- .../transformers/transformer_wan_s2v.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 9f2ca02bbdcf..41deaced4fd0 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -177,7 +177,7 @@ def forward(self, x): class MotionEncoder_tc(nn.Module): - def __init__(self, in_dim: int, hidden_dim: int, num_attention_heads=int, need_global=True): + def __init__(self, in_dim: int, hidden_dim: int, num_attention_heads: int, need_global: bool = True): super().__init__() self.num_attention_heads = num_attention_heads @@ -203,11 +203,7 @@ def forward(self, x): residual = x.clone() batch_size, num_channels, seq_len = x.shape x = self.conv1_local(x) - x = ( - x.unflatten(1, (self.num_attention_heads, -1)) - .permute(0, 1, 3, 2) - .reshape(batch_size * self.num_attention_heads, seq_len, num_channels) - ) + x = x.unflatten(1, (self.num_attention_heads, -1)).permute(0, 1, 3, 2).flatten(0, 1) x = self.norm1(x) x = self.act(x) x = x.permute(0, 2, 1) @@ -772,7 +768,10 @@ def process_motion(self, motion_latents, drop_motion_frames=False): ] motion_rope_emb = rope_precompute( flat_mot.detach().view( - 1, flat_mot.shape[1], self.num_attention_heads, self.inner_dim // self.num_attention_heads + 1, + flat_mot.shape[1], + self.config.num_attention_heads, + self.inner_dim // self.config.num_attention_heads, ), motion_grid_sizes, self.freqs, @@ -800,7 +799,7 @@ def inject_motion( add_last_motion=True, ): # Inject the motion frames token to the hidden states - if self.enable_framepack: + if self.config.enable_framepack: mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames, add_last_motion) else: mot, mot_remb = self.process_motion(motion_latents, drop_motion_frames) @@ -830,7 +829,7 @@ def after_transformer_block( input_hidden_states = hidden_states[:, :original_sequence_length].clone() # B (F H W) C input_hidden_states = input_hidden_states.unflatten(1, (merged_audio_emb_num_frames, -1)).flatten(0, 1) - if self.enable_adain and self.adain_mode == "attn_norm": + if self.config.enable_adain and self.config.adain_mode == "attn_norm": attn_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id]( input_hidden_states, temb=audio_emb_global[:, 0] ) @@ -916,7 +915,7 @@ def forward( ) timestep_proj = timestep_proj.unflatten(1, (6, -1)) - if self.enable_adain: + if self.config.enable_adain: audio_emb_global, audio_emb = audio_hidden_states audio_emb_global = audio_emb_global[:, motion_frames[1] :].clone() else: From 508cf8d96f31ffda2068cabdad52b870f10a4baf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 4 Sep 2025 15:59:52 +0300 Subject: [PATCH 038/131] Adding rope for hidden states and image --- .../transformers/transformer_wan_s2v.py | 158 +++++++++++++++--- 1 file changed, 131 insertions(+), 27 deletions(-) diff --git 
a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 41deaced4fd0..a47cf2b0468f 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -33,7 +33,6 @@ from .transformer_wan import ( WanAttention, WanAttnProcessor, - WanRotaryPosEmbed, ) @@ -503,6 +502,123 @@ def forward( return temb, timestep_proj, encoder_hidden_states, audio_hidden_states, pose_hidden_states +class WanS2VRotaryPosEmbed(nn.Module): + def __init__( + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + + freqs_cos = [] + freqs_sin = [] + + for dim in [t_dim, h_dim, w_dim]: + freq_cos, freq_sin = get_1d_rotary_pos_embed( + dim, + max_seq_len, + theta, + use_real=True, + repeat_interleave_real=True, + freqs_dtype=freqs_dtype, + ) + freqs_cos.append(freq_cos) + freqs_sin.append(freq_sin) + + self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) + self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) + + def forward(self, hidden_states: torch.Tensor, image_latents: torch.Tensor) -> torch.Tensor: + grid_sizes = torch.stack([torch.tensor(u.shape[-3:], dtype=torch.long) for u in hidden_states]) + grid_sizes = [torch.zeros_like(grid_sizes), grid_sizes, grid_sizes] + image_grid_sizes = [ + # The start index + torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1), + # The end index + torch.tensor([31, image_latents.shape[3], image_latents.shape[4]]).unsqueeze(0).repeat(batch_size, 1), + # The range + torch.tensor([1, image_latents.shape[3], image_latents.shape[4]]).unsqueeze(0).repeat(batch_size, 1), + ] + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + split_sizes = [ + self.attention_head_dim - 2 * (self.attention_head_dim // 3), + self.attention_head_dim // 3, + self.attention_head_dim // 3, + ] + + freqs_cos = self.freqs_cos.split(split_sizes, dim=1) + freqs_sin = self.freqs_sin.split(split_sizes, dim=1) + + freqs_cos_list = [] + freqs_sin_list = [] + seq_lens = [[0] * batch_size] + for grid_size in [grid_sizes, image_grid_sizes]: + for i in range(batch_size): + f_0, h_0, w_0 = grid_size[0][i] + f, h, w = grid_size[1][i] + t_f, t_h, t_w = grid_size[2][i] + seq_f, seq_h, seq_w = f - f_0, h - h_0, w - w_0 + seq_len = int(seq_f * seq_h * seq_w) + seq_lens[i].append(seq_len) + + if seq_len > 0: + if t_f > 0: + factor_f = (t_f / seq_f).item() + factor_h = (t_h / seq_h).item() + factor_w = (t_w / seq_w).item() + + # Generate a list of seq_f integers starting from f_0 and ending at math.ceil(factor_f * seq_f.item() + f_0.item()) + if f_0 >= 0: + f_sam = np.linspace(f_0.item(), (t_f + f_0).item() - 1, seq_f).astype(int).tolist() + else: + f_sam = np.linspace(-f_0.item(), (-t_f - f_0).item() + 1, seq_f).astype(int).tolist() + + h_sam = np.linspace(h_0.item(), (t_h + h_0).item() - 1, seq_h).astype(int).tolist() + w_sam = np.linspace(w_0.item(), (t_w + w_0).item() - 1, seq_w).astype(int).tolist() + + if f_0 * f < 0 or h_0 * h < 
0 or w_0 * w < 0: + raise ValueError("The RoPE is not supported for negative dimensions :S") + + freqs_cos_combined = torch.cat([ + freqs_cos[0][f_sam].view(seq_f, 1, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs_cos[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs_cos[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), + ], dim=-1).reshape(seq_len, 1, -1) + + freqs_sin_combined = torch.cat([ + freqs_sin[0][f_sam].view(seq_f, 1, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs_sin[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs_sin[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), + ], dim=-1).reshape(seq_len, 1, -1) + + freqs_cos_list.append(freqs_cos_combined) + freqs_sin_list.append(freqs_sin_combined) + + freqs_cos = torch.cat(freqs_cos_list, dim=0) + freqs_sin = torch.cat(freqs_sin_list, dim=0) + + for i in range(batch_size): + for j in range(2): + freqs_cos[seq_lens[i][j]: seq_lens[i][j + 1]] = freqs_cos[seq_lens[i][j]: seq_lens[i][j + 1]].reshape(batch_size, seq_lens[i], 1, -1) + freqs_sin[seq_lens[i][j]: seq_lens[i][j + 1]] = freqs_sin[seq_lens[i][j]: seq_lens[i][j + 1]].reshape(batch_size, seq_lens[i], 1, -1) + + return freqs_cos, freqs_sin + + @maybe_allow_in_graph class WanS2VTransformerBlock(nn.Module): def __init__( @@ -686,12 +802,11 @@ def __init__( ) -> None: super().__init__() - self.add_last_motion = add_last_motion self.inner_dim = inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels # 1. Patch & position embedding - self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.rope = WanS2VRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) if enable_framepack: @@ -856,8 +971,8 @@ def forward( encoder_hidden_states: torch.Tensor, motion_latents: torch.Tensor, audio_embeds: torch.Tensor, - image_latents: torch.Tensor = None, - pose_latents: torch.Tensor = None, + image_latents: torch.Tensor, + pose_latents: torch.Tensor, motion_frames: List[int] = [17, 5], drop_motion_frames: bool = False, add_last_motion: int = 2, @@ -899,10 +1014,11 @@ def forward( post_patch_num_frames = num_frames // p_t post_patch_height = height // p_h post_patch_width = width // p_w - add_last_motion = self.add_last_motion * add_last_motion + add_last_motion = self.config.add_last_motion * add_last_motion # 2. Patch embedding hidden_states = self.patch_embedding(hidden_states) + image_latents = self.patch_embedding(image_latents) # 3. 
Condition embeddings audio_embeds = torch.cat([audio_embeds[..., 0].repeat(1, 1, 1, motion_frames[0]), audio_embeds], dim=-1) @@ -923,10 +1039,12 @@ def forward( merged_audio_emb = audio_emb[:, motion_frames[1] :, :] hidden_states = hidden_states + pose_hidden_states - grid_sizes = torch.tensor( - [post_patch_num_frames, post_patch_height, post_patch_width], dtype=torch.long - ).unsqueeze(0) - grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]] + + hidden_states = [torch.cat([hs, il], dim=1) for hs, il in zip(hidden_states, image_latents)] + + # Rotary position embedding + rotary_emb = self.rope(hidden_states, image_latents) + hidden_states = hidden_states.flatten(2).transpose(1, 2) sequence_length = hidden_states.shape[1].to(torch.long) original_sequence_length = sequence_length @@ -944,21 +1062,10 @@ def forward( timestep_proj = timestep_proj.unsqueeze(2).repeat(1, 1, 2, 1) timestep_proj = [timestep_proj, 0] - image_latents = self.patch_embedding(image_latents) + image_latents = image_latents.flatten(2).transpose(1, 2) - image_grid_sizes = [ - [ - # The start index - torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1), - # The end index - torch.tensor([31, height, width]).unsqueeze(0).repeat(batch_size, 1), - # The range - torch.tensor([1, height, width]).unsqueeze(0).repeat(batch_size, 1), - ] - ] sequence_length = sequence_length + image_latents.shape[1].to(torch.long) - grid_sizes = grid_sizes + image_grid_sizes hidden_states = torch.cat([hidden_states, image_latents], dim=1) @@ -968,13 +1075,10 @@ def forward( mask_input = torch.zeros([1, hidden_states.shape[1]], dtype=torch.long, device=hidden_states.device) mask_input[:, original_sequence_length:] = 1 - # Rotary position embedding - rotary_emb = self.rope(hidden_states) - - hidden_states, sequence_length, pre_compute_freqs, mask_input = self.inject_motion( + hidden_states, sequence_length, rotary_emb, mask_input = self.inject_motion( hidden_states, sequence_length, - # pre_compute_freqs, + rotary_emb, mask_input, motion_latents, drop_motion_frames, From 1fcfeba29e3470fe89e3bfe3a532da6942a9c3cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 4 Sep 2025 16:02:57 +0300 Subject: [PATCH 039/131] style --- .../transformers/transformer_wan_s2v.py | 57 ++++++++++--------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index a47cf2b0468f..d5c75f1ea3e3 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -26,7 +26,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionMixin, FeedForward from ..cache_utils import CacheMixin -from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin, get_parameter_dtype from ..normalization import FP32LayerNorm @@ -539,6 +539,8 @@ def __init__( self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) def forward(self, hidden_states: torch.Tensor, image_latents: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + grid_sizes = torch.stack([torch.tensor(u.shape[-3:], dtype=torch.long) for u in hidden_states]) 
grid_sizes = [torch.zeros_like(grid_sizes), grid_sizes, grid_sizes] image_grid_sizes = [ @@ -550,10 +552,6 @@ def forward(self, hidden_states: torch.Tensor, image_latents: torch.Tensor) -> t torch.tensor([1, image_latents.shape[3], image_latents.shape[4]]).unsqueeze(0).repeat(batch_size, 1), ] - batch_size, num_channels, num_frames, height, width = hidden_states.shape - p_t, p_h, p_w = self.patch_size - ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - split_sizes = [ self.attention_head_dim - 2 * (self.attention_head_dim // 3), self.attention_head_dim // 3, @@ -574,37 +572,39 @@ def forward(self, hidden_states: torch.Tensor, image_latents: torch.Tensor) -> t seq_f, seq_h, seq_w = f - f_0, h - h_0, w - w_0 seq_len = int(seq_f * seq_h * seq_w) seq_lens[i].append(seq_len) - + if seq_len > 0: if t_f > 0: - factor_f = (t_f / seq_f).item() - factor_h = (t_h / seq_h).item() - factor_w = (t_w / seq_w).item() - # Generate a list of seq_f integers starting from f_0 and ending at math.ceil(factor_f * seq_f.item() + f_0.item()) if f_0 >= 0: f_sam = np.linspace(f_0.item(), (t_f + f_0).item() - 1, seq_f).astype(int).tolist() else: f_sam = np.linspace(-f_0.item(), (-t_f - f_0).item() + 1, seq_f).astype(int).tolist() - + h_sam = np.linspace(h_0.item(), (t_h + h_0).item() - 1, seq_h).astype(int).tolist() w_sam = np.linspace(w_0.item(), (t_w + w_0).item() - 1, seq_w).astype(int).tolist() if f_0 * f < 0 or h_0 * h < 0 or w_0 * w < 0: raise ValueError("The RoPE is not supported for negative dimensions :S") - - freqs_cos_combined = torch.cat([ - freqs_cos[0][f_sam].view(seq_f, 1, 1, -1).expand(seq_f, seq_h, seq_w, -1), - freqs_cos[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), - freqs_cos[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), - ], dim=-1).reshape(seq_len, 1, -1) - - freqs_sin_combined = torch.cat([ - freqs_sin[0][f_sam].view(seq_f, 1, 1, -1).expand(seq_f, seq_h, seq_w, -1), - freqs_sin[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), - freqs_sin[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), - ], dim=-1).reshape(seq_len, 1, -1) - + + freqs_cos_combined = torch.cat( + [ + freqs_cos[0][f_sam].view(seq_f, 1, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs_cos[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs_cos[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + + freqs_sin_combined = torch.cat( + [ + freqs_sin[0][f_sam].view(seq_f, 1, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs_sin[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs_sin[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + freqs_cos_list.append(freqs_cos_combined) freqs_sin_list.append(freqs_sin_combined) @@ -613,8 +613,12 @@ def forward(self, hidden_states: torch.Tensor, image_latents: torch.Tensor) -> t for i in range(batch_size): for j in range(2): - freqs_cos[seq_lens[i][j]: seq_lens[i][j + 1]] = freqs_cos[seq_lens[i][j]: seq_lens[i][j + 1]].reshape(batch_size, seq_lens[i], 1, -1) - freqs_sin[seq_lens[i][j]: seq_lens[i][j + 1]] = freqs_sin[seq_lens[i][j]: seq_lens[i][j + 1]].reshape(batch_size, seq_lens[i], 1, -1) + freqs_cos[seq_lens[i][j] : seq_lens[i][j + 1]] = freqs_cos[ + seq_lens[i][j] : seq_lens[i][j + 1] + ].reshape(batch_size, seq_lens[i], 1, -1) + freqs_sin[seq_lens[i][j] : seq_lens[i][j + 1]] = freqs_sin[ + seq_lens[i][j] : seq_lens[i][j + 1] + ].reshape(batch_size, 
seq_lens[i], 1, -1) return freqs_cos, freqs_sin @@ -1062,7 +1066,6 @@ def forward( timestep_proj = timestep_proj.unsqueeze(2).repeat(1, 1, 2, 1) timestep_proj = [timestep_proj, 0] - image_latents = image_latents.flatten(2).transpose(1, 2) sequence_length = sequence_length + image_latents.shape[1].to(torch.long) From 74d63811dcf362fea6f32316921c0efed22182d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 4 Sep 2025 16:13:53 +0300 Subject: [PATCH 040/131] Refactor ropes --- .../transformers/transformer_wan_s2v.py | 90 +------------------ 1 file changed, 4 insertions(+), 86 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index d5c75f1ea3e3..11de70c1eab3 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -56,74 +56,6 @@ def torch_dfs(model: nn.Module, parent_name="root"): return modules, module_names -def rope_precompute(x, grid_sizes, freqs, start=None): - b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2 - - # split freqs - if type(freqs) is list: - trainable_freqs = freqs[1] - freqs = freqs[0] - freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) - - # loop over samples - output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64)) - seq_bucket = [0] - if type(grid_sizes) is not list: - grid_sizes = [grid_sizes] - for g in grid_sizes: - if type(g) is not list: - g = [torch.zeros_like(g), g] - batch_size = g[0].shape[0] - for i in range(batch_size): - if start is None: - f_o, h_o, w_o = g[0][i] - else: - f_o, h_o, w_o = start[i] - - f, h, w = g[1][i] - t_f, t_h, t_w = g[2][i] - seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o - seq_len = int(seq_f * seq_h * seq_w) - if seq_len > 0: - if t_f > 0: - # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item()) - if f_o >= 0: - f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist() - else: - f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist() - h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist() - w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist() - - assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0 - freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj() - freqs_0 = freqs_0.view(seq_f, 1, 1, -1) - - freqs_i = torch.cat( - [ - freqs_0.expand(seq_f, seq_h, seq_w, -1), - freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), - freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), - ], - dim=-1, - ).reshape(seq_len, 1, -1) - elif t_f < 0: - freqs_i = trainable_freqs.unsqueeze(1) - # apply rotary embedding - output[i, seq_bucket[-1] : seq_bucket[-1] + seq_len] = freqs_i - seq_bucket.append(seq_bucket[-1] + seq_len) - return output - - -@torch.amp.autocast("cuda", enabled=False) -def rope_params(max_seq_len, dim, theta=10000): - assert dim % 2 == 0 - freqs = torch.outer( - torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)) - ) - freqs = torch.polar(torch.ones_like(freqs), freqs) - return freqs - - class AdaLayerNorm(nn.Module): r""" Norm layer modified to incorporate timestep embeddings. 
@@ -350,15 +282,7 @@ def __init__( self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long) - head_dim = inner_dim // num_attention_heads - self.freqs = torch.cat( - [ - rope_params(1024, head_dim - 4 * (head_dim // 6)), - rope_params(1024, 2 * (head_dim // 6)), - rope_params(1024, 2 * (head_dim // 6)), - ], - dim=1, - ) + self.rope = WanS2VRotaryPosEmbed(inner_dim // num_attention_heads, max_seq_len=1024) def forward(self, motion_latents, add_last_motion=2): mot = [] @@ -433,13 +357,11 @@ def forward(self, motion_latents, add_last_motion=2): grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x - motion_rope_emb = rope_precompute( + motion_rope_emb = self.rope( motion_lat.detach().view( 1, motion_lat.shape[1], self.num_attention_heads, self.inner_dim // self.num_attention_heads ), grid_sizes, - self.freqs, - start=None, ) mot.append(motion_lat) @@ -506,14 +428,12 @@ class WanS2VRotaryPosEmbed(nn.Module): def __init__( self, attention_head_dim: int, - patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0, ): super().__init__() self.attention_head_dim = attention_head_dim - self.patch_size = patch_size self.max_seq_len = max_seq_len h_dim = w_dim = 2 * (attention_head_dim // 6) @@ -810,7 +730,7 @@ def __init__( out_channels = out_channels or in_channels # 1. Patch & position embedding - self.rope = WanS2VRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.rope = WanS2VRotaryPosEmbed(attention_head_dim, rope_max_seq_len) self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) if enable_framepack: @@ -885,7 +805,7 @@ def process_motion(self, motion_latents, drop_motion_frames=False): torch.tensor([self.latent_motion_frames, height, width]).unsqueeze(0), ] ] - motion_rope_emb = rope_precompute( + motion_rope_emb = self.rope( flat_mot.detach().view( 1, flat_mot.shape[1], @@ -893,8 +813,6 @@ def process_motion(self, motion_latents, drop_motion_frames=False): self.inner_dim // self.config.num_attention_heads, ), motion_grid_sizes, - self.freqs, - start=None, ) mot_remb.append(motion_rope_emb) flattern_mot.append(flat_mot) From 9cd08bc1a47a0167b6eda2843365988db27fc0c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 4 Sep 2025 18:48:48 +0300 Subject: [PATCH 041/131] refactor --- .../transformers/transformer_wan_s2v.py | 81 +++++++++++-------- 1 file changed, 49 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 11de70c1eab3..42105b955268 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -265,6 +265,7 @@ def __init__( 16, ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion + patch_size=(1, 2, 2), *args, **kwargs, ): @@ -282,7 +283,7 @@ def __init__( self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long) - self.rope = WanS2VRotaryPosEmbed(inner_dim // num_attention_heads, max_seq_len=1024) + self.rope = WanS2VRotaryPosEmbed(inner_dim // num_attention_heads, patch_size=patch_size, max_seq_len=1024) def forward(self, motion_latents, 
add_last_motion=2): mot = [] @@ -428,6 +429,7 @@ class WanS2VRotaryPosEmbed(nn.Module): def __init__( self, attention_head_dim: int, + patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0, ): @@ -458,19 +460,33 @@ def __init__( self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) - def forward(self, hidden_states: torch.Tensor, image_latents: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, image_latents: torch.Tensor, grid_sizes: List[List[torch.Tensor]] + ) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + if grid_sizes is None: + grid_sizes = torch.tensor([ppf, pph, ppw]).unsqueeze(0).repeat(batch_size, 1) + grid_sizes = [torch.zeros_like(grid_sizes), grid_sizes, grid_sizes] + + image_grid_sizes = [ + # The start index + torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1), + # The end index + torch.tensor([31, image_latents.shape[3] // p_h, image_latents.shape[4] // p_w]) + .unsqueeze(0) + .repeat(batch_size, 1), + # The range + torch.tensor([1, image_latents.shape[3] // p_h, image_latents.shape[4] // p_w]) + .unsqueeze(0) + .repeat(batch_size, 1), + ] - grid_sizes = torch.stack([torch.tensor(u.shape[-3:], dtype=torch.long) for u in hidden_states]) - grid_sizes = [torch.zeros_like(grid_sizes), grid_sizes, grid_sizes] - image_grid_sizes = [ - # The start index - torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1), - # The end index - torch.tensor([31, image_latents.shape[3], image_latents.shape[4]]).unsqueeze(0).repeat(batch_size, 1), - # The range - torch.tensor([1, image_latents.shape[3], image_latents.shape[4]]).unsqueeze(0).repeat(batch_size, 1), - ] + grids = [grid_sizes, image_grid_sizes] + else: # FramePack's RoPE + grids = grid_sizes split_sizes = [ self.attention_head_dim - 2 * (self.attention_head_dim // 3), @@ -478,13 +494,13 @@ def forward(self, hidden_states: torch.Tensor, image_latents: torch.Tensor) -> t self.attention_head_dim // 3, ] - freqs_cos = self.freqs_cos.split(split_sizes, dim=1) - freqs_sin = self.freqs_sin.split(split_sizes, dim=1) + freqs_cos = [self.freqs_cos.split(split_sizes, dim=1) for _ in range(batch_size)] + freqs_sin = [self.freqs_sin.split(split_sizes, dim=1) for _ in range(batch_size)] - freqs_cos_list = [] - freqs_sin_list = [] - seq_lens = [[0] * batch_size] - for grid_size in [grid_sizes, image_grid_sizes]: + freqs_cos_list = [[] for _ in range(batch_size)] + freqs_sin_list = [[] for _ in range(batch_size)] + seq_lens = [[0] for _ in range(batch_size)] + for grid_size in grids: for i in range(batch_size): f_0, h_0, w_0 = grid_size[0][i] f, h, w = grid_size[1][i] @@ -505,7 +521,7 @@ def forward(self, hidden_states: torch.Tensor, image_latents: torch.Tensor) -> t w_sam = np.linspace(w_0.item(), (t_w + w_0).item() - 1, seq_w).astype(int).tolist() if f_0 * f < 0 or h_0 * h < 0 or w_0 * w < 0: - raise ValueError("The RoPE is not supported for negative dimensions :S") + raise ValueError("The RoPE is not supported for negative dimensions.") freqs_cos_combined = torch.cat( [ @@ -525,18 +541,19 @@ def forward(self, hidden_states: torch.Tensor, image_latents: torch.Tensor) -> t dim=-1, ).reshape(seq_len, 1, -1) - freqs_cos_list.append(freqs_cos_combined) - freqs_sin_list.append(freqs_sin_combined) + 
freqs_cos_list[i].append(freqs_cos_combined) + freqs_sin_list[i].append(freqs_sin_combined) - freqs_cos = torch.cat(freqs_cos_list, dim=0) - freqs_sin = torch.cat(freqs_sin_list, dim=0) + for i in range(batch_size): + freqs_cos_list[i] = torch.cat(freqs_cos_list[i], dim=0) + freqs_sin_list[i] = torch.cat(freqs_sin_list[i], dim=0) for i in range(batch_size): for j in range(2): - freqs_cos[seq_lens[i][j] : seq_lens[i][j + 1]] = freqs_cos[ + freqs_cos[i][seq_lens[i][j] : seq_lens[i][j + 1]] = freqs_cos_list[i][ seq_lens[i][j] : seq_lens[i][j + 1] ].reshape(batch_size, seq_lens[i], 1, -1) - freqs_sin[seq_lens[i][j] : seq_lens[i][j + 1]] = freqs_sin[ + freqs_sin[i][seq_lens[i][j] : seq_lens[i][j + 1]] = freqs_sin_list[i][ seq_lens[i][j] : seq_lens[i][j + 1] ].reshape(batch_size, seq_lens[i], 1, -1) @@ -730,7 +747,7 @@ def __init__( out_channels = out_channels or in_channels # 1. Patch & position embedding - self.rope = WanS2VRotaryPosEmbed(attention_head_dim, rope_max_seq_len) + self.rope = WanS2VRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) if enable_framepack: @@ -739,6 +756,7 @@ def __init__( num_attention_heads=num_attention_heads, zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode, + patch_size=patch_size, ) self.trainable_condition_mask = nn.Embedding(3, inner_dim) @@ -938,6 +956,9 @@ def forward( post_patch_width = width // p_w add_last_motion = self.config.add_last_motion * add_last_motion + # 1. Rotary position embeddings + rotary_emb = self.rope(hidden_states, image_latents) + # 2. Patch embedding hidden_states = self.patch_embedding(hidden_states) image_latents = self.patch_embedding(image_latents) @@ -962,12 +983,10 @@ def forward( hidden_states = hidden_states + pose_hidden_states + hidden_states = hidden_states.flatten(2).transpose(1, 2) + image_latents = image_latents.flatten(2).transpose(1, 2) hidden_states = [torch.cat([hs, il], dim=1) for hs, il in zip(hidden_states, image_latents)] - # Rotary position embedding - rotary_emb = self.rope(hidden_states, image_latents) - - hidden_states = hidden_states.flatten(2).transpose(1, 2) sequence_length = hidden_states.shape[1].to(torch.long) original_sequence_length = sequence_length @@ -984,8 +1003,6 @@ def forward( timestep_proj = timestep_proj.unsqueeze(2).repeat(1, 1, 2, 1) timestep_proj = [timestep_proj, 0] - image_latents = image_latents.flatten(2).transpose(1, 2) - sequence_length = sequence_length + image_latents.shape[1].to(torch.long) hidden_states = torch.cat([hidden_states, image_latents], dim=1) From fd3af1db6168795737358abf8106196e242cb6b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 4 Sep 2025 21:11:28 +0300 Subject: [PATCH 042/131] up --- src/diffusers/models/transformers/transformer_wan_s2v.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 42105b955268..b9a4c9ffaf05 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -959,7 +959,7 @@ def forward( # 1. Rotary position embeddings rotary_emb = self.rope(hidden_states, image_latents) - # 2. Patch embedding + # 2. 
Patch embeddings hidden_states = self.patch_embedding(hidden_states) image_latents = self.patch_embedding(image_latents) @@ -985,7 +985,7 @@ def forward( hidden_states = hidden_states.flatten(2).transpose(1, 2) image_latents = image_latents.flatten(2).transpose(1, 2) - hidden_states = [torch.cat([hs, il], dim=1) for hs, il in zip(hidden_states, image_latents)] + hidden_states = torch.cat([hidden_states, image_latents], dim=1) sequence_length = hidden_states.shape[1].to(torch.long) original_sequence_length = sequence_length @@ -1005,8 +1005,6 @@ def forward( sequence_length = sequence_length + image_latents.shape[1].to(torch.long) - hidden_states = torch.cat([hidden_states, image_latents], dim=1) - # Initialize masks to indicate noisy latent, image latent, and motion latent. # However, at this point, only the first two (noisy and image latents) are marked; # the marking of motion latent will be implemented inside `inject_motion`. From 97991aa68ff935502970d77ff2e6fd7661834f16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 5 Sep 2025 08:27:19 +0300 Subject: [PATCH 043/131] Preserve the lost dimension explicitly --- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index b9a4c9ffaf05..ada2c889790e 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -964,7 +964,7 @@ def forward( image_latents = self.patch_embedding(image_latents) # 3. Condition embeddings - audio_embeds = torch.cat([audio_embeds[..., 0].repeat(1, 1, 1, motion_frames[0]), audio_embeds], dim=-1) + audio_embeds = torch.cat([audio_embeds[..., 0].unsqueeze(-1).repeat(1, 1, 1, motion_frames[0]), audio_embeds], dim=-1) if self.config.zero_timestep: timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) From 2048861296a98fd75f0166da9106dd0251212e67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 5 Sep 2025 11:10:29 +0300 Subject: [PATCH 044/131] Use complex rope temporarily --- .../transformers/transformer_wan_s2v.py | 275 ++++++++++++------ 1 file changed, 191 insertions(+), 84 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index ada2c889790e..9f27d5a05918 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -25,6 +25,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput @@ -32,7 +33,6 @@ from ..normalization import FP32LayerNorm from .transformer_wan import ( WanAttention, - WanAttnProcessor, ) @@ -56,6 +56,130 @@ def torch_dfs(model: nn.Module, parent_name="root"): return modules, module_names +def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): + # encoder_hidden_states is only passed for cross-attention + if encoder_hidden_states is None: + 
encoder_hidden_states = hidden_states + + if attn.fused_projections: + if attn.cross_attention_dim_head is None: + # In self-attention layers, we can fuse the entire QKV projection into a single linear + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + # In cross-attention layers, we can only fuse the KV projections into a single linear + query = attn.to_q(hidden_states) + key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + return query, key, value + + +def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor): + if attn.fused_projections: + key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1) + else: + key_img = attn.add_k_proj(encoder_hidden_states_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + return key_img, value_img + + +class WanAttnProcessor: + _attention_backend = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: "WanAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if rotary_emb is not None: + + def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): + # dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64 + n = query.size(2) + # loop over samples + output = [] + for i in range(query.size(0)): + s = query.size(1) + x_i = torch.view_as_complex(query[i, :s].to(torch.float64).reshape(s, n, -1, 2)) + freqs_i = freqs[i, :s] + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, hidden_states[i, s:]]) + # append to collection + output.append(x_i) + return torch.stack(output).float() + + query = apply_rotary_emb(query, rotary_emb) + key = apply_rotary_emb(key, rotary_emb) + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)) + value_img = value_img.unflatten(2, (attn.heads, -1)) + + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + ) + hidden_states_img = hidden_states_img.flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + 
dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + class AdaLayerNorm(nn.Module): r""" Norm layer modified to incorporate timestep embeddings. @@ -442,26 +566,21 @@ def __init__( t_dim = attention_head_dim - h_dim - w_dim freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 - freqs_cos = [] - freqs_sin = [] + freqs = [] for dim in [t_dim, h_dim, w_dim]: - freq_cos, freq_sin = get_1d_rotary_pos_embed( - dim, - max_seq_len, - theta, - use_real=True, - repeat_interleave_real=True, - freqs_dtype=freqs_dtype, + freq = get_1d_rotary_pos_embed( + dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype ) - freqs_cos.append(freq_cos) - freqs_sin.append(freq_sin) + freqs.append(freq) - self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) - self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) + self.freqs = torch.cat(freqs, dim=1) def forward( - self, hidden_states: torch.Tensor, image_latents: torch.Tensor, grid_sizes: List[List[torch.Tensor]] + self, + hidden_states: torch.Tensor, + image_latents: torch.Tensor, + grid_sizes: Optional[List[List[torch.Tensor]]] = None, ) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.patch_size @@ -489,75 +608,62 @@ def forward( grids = grid_sizes split_sizes = [ - self.attention_head_dim - 2 * (self.attention_head_dim // 3), - self.attention_head_dim // 3, - self.attention_head_dim // 3, + self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 3), + self.attention_head_dim // 6, + self.attention_head_dim // 6, ] - freqs_cos = [self.freqs_cos.split(split_sizes, dim=1) for _ in range(batch_size)] - freqs_sin = [self.freqs_sin.split(split_sizes, dim=1) for _ in range(batch_size)] + S = ppf * pph * ppw + image_latents.shape[3] // p_h * image_latents.shape[4] // p_w + + num_heads = num_channels // self.attention_head_dim + + freqs = self.freqs.split(split_sizes, dim=1) + + # loop over samples + output = torch.view_as_complex( + torch.zeros((batch_size, S, num_heads, -1, 2), device=hidden_states.device).to(torch.float64) + ) + seq_bucket = [0] - freqs_cos_list = [[] for _ in range(batch_size)] - freqs_sin_list = [[] for _ in range(batch_size)] - seq_lens = [[0] for _ in range(batch_size)] - for grid_size in grids: + for g in grids: + if type(g) is not list: + g = [torch.zeros_like(g), g] + batch_size = g[0].shape[0] for i in range(batch_size): - f_0, h_0, w_0 = grid_size[0][i] - f, h, w = grid_size[1][i] - t_f, t_h, t_w = grid_size[2][i] - seq_f, seq_h, seq_w = f - f_0, h - h_0, w - w_0 - seq_len = int(seq_f * seq_h * seq_w) - seq_lens[i].append(seq_len) + f_o, h_o, w_o = g[0][i] + f, h, w = g[1][i] + t_f, t_h, t_w = g[2][i] + seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o + seq_len = int(seq_f * seq_h * seq_w) if seq_len > 0: if t_f > 0: - # Generate a list of seq_f integers starting from f_0 and ending at math.ceil(factor_f * seq_f.item() + f_0.item()) - if f_0 >= 0: - f_sam = np.linspace(f_0.item(), (t_f + f_0).item() - 1, seq_f).astype(int).tolist() + # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * 
seq_f.item() + f_o.item()) + if f_o >= 0: + f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist() else: - f_sam = np.linspace(-f_0.item(), (-t_f - f_0).item() + 1, seq_f).astype(int).tolist() - - h_sam = np.linspace(h_0.item(), (t_h + h_0).item() - 1, seq_h).astype(int).tolist() - w_sam = np.linspace(w_0.item(), (t_w + w_0).item() - 1, seq_w).astype(int).tolist() + f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist() + h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist() + w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist() - if f_0 * f < 0 or h_0 * h < 0 or w_0 * w < 0: - raise ValueError("The RoPE is not supported for negative dimensions.") + assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0 + freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj() + freqs_0 = freqs_0.view(seq_f, 1, 1, -1) - freqs_cos_combined = torch.cat( + freqs_i = torch.cat( [ - freqs_cos[0][f_sam].view(seq_f, 1, 1, -1).expand(seq_f, seq_h, seq_w, -1), - freqs_cos[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), - freqs_cos[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), + freqs_0.expand(seq_f, seq_h, seq_w, -1), + freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), ], dim=-1, ).reshape(seq_len, 1, -1) - freqs_sin_combined = torch.cat( - [ - freqs_sin[0][f_sam].view(seq_f, 1, 1, -1).expand(seq_f, seq_h, seq_w, -1), - freqs_sin[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), - freqs_sin[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), - ], - dim=-1, - ).reshape(seq_len, 1, -1) + # apply rotary embedding + output[i, seq_bucket[-1] : seq_bucket[-1] + seq_len] = freqs_i + seq_bucket.append(seq_bucket[-1] + seq_len) - freqs_cos_list[i].append(freqs_cos_combined) - freqs_sin_list[i].append(freqs_sin_combined) - - for i in range(batch_size): - freqs_cos_list[i] = torch.cat(freqs_cos_list[i], dim=0) - freqs_sin_list[i] = torch.cat(freqs_sin_list[i], dim=0) - - for i in range(batch_size): - for j in range(2): - freqs_cos[i][seq_lens[i][j] : seq_lens[i][j + 1]] = freqs_cos_list[i][ - seq_lens[i][j] : seq_lens[i][j + 1] - ].reshape(batch_size, seq_lens[i], 1, -1) - freqs_sin[i][seq_lens[i][j] : seq_lens[i][j + 1]] = freqs_sin_list[i][ - seq_lens[i][j] : seq_lens[i][j + 1] - ].reshape(batch_size, seq_lens[i], 1, -1) - - return freqs_cos, freqs_sin + return output @maybe_allow_in_graph @@ -964,7 +1070,9 @@ def forward( image_latents = self.patch_embedding(image_latents) # 3. 
Condition embeddings - audio_embeds = torch.cat([audio_embeds[..., 0].unsqueeze(-1).repeat(1, 1, 1, motion_frames[0]), audio_embeds], dim=-1) + audio_embeds = torch.cat( + [audio_embeds[..., 0].unsqueeze(-1).repeat(1, 1, 1, motion_frames[0]), audio_embeds], dim=-1 + ) if self.config.zero_timestep: timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) @@ -985,25 +1093,11 @@ def forward( hidden_states = hidden_states.flatten(2).transpose(1, 2) image_latents = image_latents.flatten(2).transpose(1, 2) - hidden_states = torch.cat([hidden_states, image_latents], dim=1) sequence_length = hidden_states.shape[1].to(torch.long) original_sequence_length = sequence_length - - if self.config.zero_timestep: - temb = temb[:-1] - zero_timestep_proj = timestep_proj[-1:] - timestep_proj = timestep_proj[:-1] - timestep_proj = torch.cat( - [timestep_proj.unsqueeze(2), zero_timestep_proj.unsqueeze(2).repeat(timestep_proj.shape[0], 1, 1, 1)], - dim=2, - ) - timestep_proj = [timestep_proj, original_sequence_length] - else: - timestep_proj = timestep_proj.unsqueeze(2).repeat(1, 1, 2, 1) - timestep_proj = [timestep_proj, 0] - sequence_length = sequence_length + image_latents.shape[1].to(torch.long) + hidden_states = torch.cat([hidden_states, image_latents], dim=1) # Initialize masks to indicate noisy latent, image latent, and motion latent. # However, at this point, only the first two (noisy and image latents) are marked; @@ -1023,6 +1117,19 @@ def forward( hidden_states = hidden_states + self.trainable_condition_mask(mask_input).to(hidden_states.dtype) + if self.config.zero_timestep: + temb = temb[:-1] + zero_timestep_proj = timestep_proj[-1:] + timestep_proj = timestep_proj[:-1] + timestep_proj = torch.cat( + [timestep_proj.unsqueeze(2), zero_timestep_proj.unsqueeze(2).repeat(timestep_proj.shape[0], 1, 1, 1)], + dim=2, + ) + timestep_proj = [timestep_proj, original_sequence_length] + else: + timestep_proj = timestep_proj.unsqueeze(2).repeat(1, 1, 2, 1) + timestep_proj = [timestep_proj, 0] + merged_audio_emb_num_frames = merged_audio_emb.shape[1] # B F N C attn_audio_emb = merged_audio_emb.flatten(0, 1) audio_emb_global = audio_emb_global.flatten(0, 1) From 17166e2122465685109c4dd64ffafbf6cab1370d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 5 Sep 2025 13:34:49 +0300 Subject: [PATCH 045/131] upp --- .../transformers/transformer_wan_s2v.py | 75 ++++++++++--------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 9f27d5a05918..aedc30b8cfea 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -412,25 +412,25 @@ def __init__( def forward(self, motion_latents, add_last_motion=2): mot = [] mot_remb = [] - for m in motion_latents: - lat_height, lat_width = m.shape[2], m.shape[3] - padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, lat_width).to( - device=m.device, dtype=m.dtype + for motion_latent in motion_latents: + latent_height, latent_width = motion_latent.shape[2], motion_latent.shape[3] + padd_latent = torch.zeros(16, self.zip_frame_buckets.sum(), latent_height, latent_width).to( + device=motion_latent.device, dtype=motion_latent.dtype ) - overlap_frame = min(padd_lat.shape[1], m.shape[1]) + overlap_frame = min(padd_latent.shape[1], motion_latent.shape[1]) if overlap_frame > 0: - padd_lat[:, -overlap_frame:] = 
m[:, -overlap_frame:] + padd_latent[:, -overlap_frame:] = motion_latent[:, -overlap_frame:] if add_last_motion < 2 and self.drop_mode != "drop": - zero_end_frame = self.zip_frame_buckets[: self.zip_frame_buckets.__len__() - add_last_motion - 1].sum() - padd_lat[:, -zero_end_frame:] = 0 + zero_end_frame = self.zip_frame_buckets[: len(self.zip_frame_buckets) - add_last_motion - 1].sum() + padd_latent[:, -zero_end_frame:] = 0 - padd_lat = padd_lat.unsqueeze(0) - clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[ + padd_latent = padd_latent.unsqueeze(0) + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_latent[ :, :, -self.zip_frame_buckets.sum() :, :, : ].split(list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2, 1 - # patchify + # Patchify clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2) clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2) clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2) @@ -441,7 +441,7 @@ def forward(self, motion_latents, add_last_motion=2): motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) - # rope + # RoPE start_time_id = -(self.zip_frame_buckets[:1].sum()) end_time_id = start_time_id + self.zip_frame_buckets[0] grid_sizes = ( @@ -450,8 +450,8 @@ def forward(self, motion_latents, add_last_motion=2): else [ [ torch.tensor([start_time_id, 0, 0]).unsqueeze(0), - torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0), - torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0), + torch.tensor([end_time_id, latent_height // 2, latent_width // 2]).unsqueeze(0), + torch.tensor([self.zip_frame_buckets[0], latent_height // 2, latent_width // 2]).unsqueeze(0), ] ] ) @@ -464,8 +464,8 @@ def forward(self, motion_latents, add_last_motion=2): else [ [ torch.tensor([start_time_id, 0, 0]).unsqueeze(0), - torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0), - torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0), + torch.tensor([end_time_id, latent_height // 4, latent_width // 4]).unsqueeze(0), + torch.tensor([self.zip_frame_buckets[1], latent_height // 2, latent_width // 2]).unsqueeze(0), ] ] ) @@ -475,18 +475,16 @@ def forward(self, motion_latents, add_last_motion=2): grid_sizes_4x = [ [ torch.tensor([start_time_id, 0, 0]).unsqueeze(0), - torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0), - torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]).unsqueeze(0), + torch.tensor([end_time_id, latent_height // 8, latent_width // 8]).unsqueeze(0), + torch.tensor([self.zip_frame_buckets[2], latent_height // 2, latent_width // 2]).unsqueeze(0), ] ] grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x motion_rope_emb = self.rope( - motion_lat.detach().view( - 1, motion_lat.shape[1], self.num_attention_heads, self.inner_dim // self.num_attention_heads - ), - grid_sizes, + motion_lat.detach().view(1, motion_lat.shape[1], self.num_attention_heads, self.inner_dim // self.num_attention_heads), + grid_sizes=grid_sizes, ) mot.append(motion_lat) @@ -582,17 +580,21 @@ def forward( image_latents: torch.Tensor, grid_sizes: Optional[List[List[torch.Tensor]]] = None, ) -> torch.Tensor: - batch_size, num_channels, num_frames, height, width = hidden_states.shape - p_t, p_h, p_w = self.patch_size - ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + if grid_sizes is None: + batch_size, num_channels, 
num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + grid_sizes = torch.tensor([ppf, pph, ppw]).unsqueeze(0).repeat(batch_size, 1) grid_sizes = [torch.zeros_like(grid_sizes), grid_sizes, grid_sizes] image_grid_sizes = [ # The start index - torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor([30, 0, 0]) + .unsqueeze(0) + .repeat(batch_size, 1), # The end index torch.tensor([31, image_latents.shape[3] // p_h, image_latents.shape[4] // p_w]) .unsqueeze(0) @@ -604,7 +606,10 @@ def forward( ] grids = [grid_sizes, image_grid_sizes] + S = ppf * pph * ppw + image_latents.shape[3] // p_h * image_latents.shape[4] // p_w + num_heads = num_channels // self.attention_head_dim else: # FramePack's RoPE + batch_size, S, num_heads, _ = hidden_states.shape grids = grid_sizes split_sizes = [ @@ -613,18 +618,13 @@ def forward( self.attention_head_dim // 6, ] - S = ppf * pph * ppw + image_latents.shape[3] // p_h * image_latents.shape[4] // p_w - - num_heads = num_channels // self.attention_head_dim - freqs = self.freqs.split(split_sizes, dim=1) - # loop over samples + # Loop over samples output = torch.view_as_complex( torch.zeros((batch_size, S, num_heads, -1, 2), device=hidden_states.device).to(torch.float64) ) seq_bucket = [0] - for g in grids: if type(g) is not list: g = [torch.zeros_like(g), g] @@ -944,6 +944,7 @@ def process_motion(self, motion_latents, drop_motion_frames=False): def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2): flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion) + if drop_motion_frames: return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb] else: @@ -1102,8 +1103,8 @@ def forward( # Initialize masks to indicate noisy latent, image latent, and motion latent. # However, at this point, only the first two (noisy and image latents) are marked; # the marking of motion latent will be implemented inside `inject_motion`. - mask_input = torch.zeros([1, hidden_states.shape[1]], dtype=torch.long, device=hidden_states.device) - mask_input[:, original_sequence_length:] = 1 + mask_input = torch.zeros([1, hidden_states.shape[1]], dtype=torch.long, device=hidden_states.device).unsqueeze(0).repeat(batch_size, 1, 1) + mask_input[:, :, original_sequence_length:] = 1 hidden_states, sequence_length, rotary_emb, mask_input = self.inject_motion( hidden_states, @@ -1115,6 +1116,10 @@ def forward( add_last_motion, ) + hidden_states = torch.cat(hidden_states) + rotary_emb = torch.cat(rotary_emb) + mask_input = torch.cat(mask_input) + hidden_states = hidden_states + self.trainable_condition_mask(mask_input).to(hidden_states.dtype) if self.config.zero_timestep: @@ -1160,6 +1165,8 @@ def forward( audio_emb_global, ) + hidden_states = hidden_states[:, :original_sequence_length] + # 6. 
Output norm, projection & unpatchify shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) From b9a7149942f616c8e83b7f64593bdbabdf0ed84c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 5 Sep 2025 13:35:20 +0300 Subject: [PATCH 046/131] style --- .../models/transformers/transformer_wan_s2v.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index aedc30b8cfea..1803f8fa2dfe 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -483,7 +483,9 @@ def forward(self, motion_latents, add_last_motion=2): grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x motion_rope_emb = self.rope( - motion_lat.detach().view(1, motion_lat.shape[1], self.num_attention_heads, self.inner_dim // self.num_attention_heads), + motion_lat.detach().view( + 1, motion_lat.shape[1], self.num_attention_heads, self.inner_dim // self.num_attention_heads + ), grid_sizes=grid_sizes, ) @@ -580,8 +582,6 @@ def forward( image_latents: torch.Tensor, grid_sizes: Optional[List[List[torch.Tensor]]] = None, ) -> torch.Tensor: - - if grid_sizes is None: batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.patch_size @@ -592,9 +592,7 @@ def forward( image_grid_sizes = [ # The start index - torch.tensor([30, 0, 0]) - .unsqueeze(0) - .repeat(batch_size, 1), + torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1), # The end index torch.tensor([31, image_latents.shape[3] // p_h, image_latents.shape[4] // p_w]) .unsqueeze(0) @@ -944,7 +942,7 @@ def process_motion(self, motion_latents, drop_motion_frames=False): def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2): flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion) - + if drop_motion_frames: return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb] else: @@ -1103,7 +1101,11 @@ def forward( # Initialize masks to indicate noisy latent, image latent, and motion latent. # However, at this point, only the first two (noisy and image latents) are marked; # the marking of motion latent will be implemented inside `inject_motion`. 
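Editor's note: the three comment lines above describe a three-way token-type mask — 0 marks the noisy video latents, 1 the reference-image latents, and 2 the motion latents that `inject_motion` fills in later. As a rough illustration of that scheme, assuming `trainable_condition_mask` behaves like a learned per-type embedding added onto the sequence (an assumption about the module, not a statement of its implementation), here is a minimal sketch with toy shapes:

```python
# Editorial sketch only; toy shapes, not the transformer's real configuration.
import torch
import torch.nn as nn

batch, dim = 1, 16
noisy_tokens, image_tokens, motion_tokens = 6, 2, 3
hidden_states = torch.randn(batch, noisy_tokens + image_tokens + motion_tokens, dim)

# 0 = noisy video latents, 1 = reference-image latents, 2 = motion latents
mask = torch.zeros(batch, hidden_states.shape[1], dtype=torch.long)
mask[:, noisy_tokens : noisy_tokens + image_tokens] = 1
mask[:, noisy_tokens + image_tokens :] = 2

token_type_embedding = nn.Embedding(3, dim)  # one learned vector per latent type
hidden_states = hidden_states + token_type_embedding(mask)
print(hidden_states.shape)  # torch.Size([1, 11, 16])
```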
- mask_input = torch.zeros([1, hidden_states.shape[1]], dtype=torch.long, device=hidden_states.device).unsqueeze(0).repeat(batch_size, 1, 1) + mask_input = ( + torch.zeros([1, hidden_states.shape[1]], dtype=torch.long, device=hidden_states.device) + .unsqueeze(0) + .repeat(batch_size, 1, 1) + ) mask_input[:, :, original_sequence_length:] = 1 hidden_states, sequence_length, rotary_emb, mask_input = self.inject_motion( From 244005a81fe4a27e3e1482591649e10df3984276 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 5 Sep 2025 14:18:55 +0300 Subject: [PATCH 047/131] fix --- docs/source/en/api/pipelines/wan.md | 2 +- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 +- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 3 +-- src/diffusers/utils/__init__.py | 2 +- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 48caf926263c..9ece588ccba1 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -258,7 +258,7 @@ from transformers import Wav2Vec2ForCTC model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" -audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32) +audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", torch_dtype=torch.float32) vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) pipe = WanSpeechToVideoPipeline.from_pretrained( model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16 diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 1803f8fa2dfe..5d68cbac0812 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -711,7 +711,7 @@ def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - temb: List[torch.Tensor, torch.Tensor], + temb: Tuple[torch.Tensor, torch.Tensor], rotary_emb: torch.Tensor, ) -> torch.Tensor: seg_idx = temb[1].item() diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index c4d59ba4c608..d0208f0aa376 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -1005,8 +1005,7 @@ def __call__( # Accumulate latents so as to decode them all at once at the end all_latents.append(segment_latents) - all_latents = torch.cat(all_latents, dim=2) - latents = all_latents + latents = torch.cat(all_latents, dim=2) self._current_timestep = None diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 63932221b207..414be41857fb 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -122,7 +122,7 @@ is_xformers_version, requires_backends, ) -from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video +from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video, load_audio from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( From fde574d2bcf4d8fcba4306cd256c74c2e536a428 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 5 Sep 2025 14:49:10 +0300 Subject: [PATCH 048/131] fix: correct key names in S2V transformer mapping for audio components --- scripts/convert_wan_to_diffusers.py | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 1c7bf56ed5c1..7ef1926fed6e 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -148,8 +148,8 @@ "attn2.to_v_img": "attn2.add_v_proj", "attn2.norm_k_img": "attn2.norm_added_k", # S2V-specific audio component mappings - "casual_audio_encoder.encoder": "condition_embedder.casual_audio_encoder.encoder", - "casual_audio_encoder.weights": "condition_embedder.casual_audio_encoder.weights", + "casual_audio_encoder.encoder": "condition_embedder.causal_audio_encoder.encoder", + "casual_audio_encoder.weights": "condition_embedder.causal_audio_encoder.weights", # Pose condition encoder mappings "cond_encoder.weight": "condition_embedder.pose_embedder.weight", "cond_encoder.bias": "condition_embedder.pose_embedder.bias", From 83567bf55932931598a3b5d94865cf42e9d59b77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 5 Sep 2025 16:51:51 +0300 Subject: [PATCH 049/131] fixes --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 8 ++++---- src/diffusers/utils/loading_utils.py | 7 ------- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index d0208f0aa376..a82aa1544d4e 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -343,7 +343,7 @@ def encode_audio( batch_audio_eb.append(frame_audio_embed) audio_embed_bucket = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) - audio_embed_bucket = audio_embed_bucket.to(device, self.config.dtype) + audio_embed_bucket = audio_embed_bucket.to(device, self.dtype) audio_embed_bucket = audio_embed_bucket.unsqueeze(0) if len(audio_embed_bucket.shape) == 3: audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) @@ -497,8 +497,8 @@ def check_inputs( raise ValueError( "Provide either `audio` or `audio_embeds`. Cannot leave both `audio` and `audio_embeds` undefined." ) - elif audio is not None and not isinstance(audio, (torch.Tensor, list)): - raise ValueError(f"`audio` has to be of type `torch.Tensor` or `list` but is {type(audio)}") + elif audio is not None and not isinstance(audio, (np.ndarray)): + raise ValueError(f"`audio` has to be of type `np.ndarray` but is {type(audio)}") def prepare_latents( self, @@ -866,7 +866,7 @@ def __call__( ) if num_chunks is None or num_chunks > num_chunks_audio: num_chunks = num_chunks_audio - audio_embeds = audio_embeds.repeat(batch_size, 1, 1) + #audio_embeds = audio_embeds.repeat(batch_size, 1, 1) audio_embeds = audio_embeds.to(transformer_dtype) latent_motion_frames = (self.motion_frames + 3) // 4 diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index 31aec907d712..e7d75d77362a 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -175,13 +175,6 @@ def load_audio( "Incorrect format used for the audio. Should be a URL linking to an audio, a local path, or a PIL audio." 
) - # audio = PIL.ImageOps.exif_transpose(audio) - - if convert_method is not None: - audio = convert_method(audio) - else: - audio = audio.convert("RGB") - return audio, sample_rate From 551c74e16ef744973f0dd00ffa63253c0ed2e561 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 5 Sep 2025 17:39:34 +0300 Subject: [PATCH 050/131] Fix errors encountering during inference --- .../models/transformers/transformer_wan_s2v.py | 13 ++++++++----- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 14 ++++++-------- src/diffusers/utils/__init__.py | 2 +- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 5d68cbac0812..63e0a36f1963 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -560,6 +560,7 @@ def __init__( super().__init__() self.attention_head_dim = attention_head_dim + self.patch_size = patch_size self.max_seq_len = max_seq_len h_dim = w_dim = 2 * (attention_head_dim // 6) @@ -579,7 +580,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - image_latents: torch.Tensor, + image_latents: Optional[torch.Tensor] = None, grid_sizes: Optional[List[List[torch.Tensor]]] = None, ) -> torch.Tensor: if grid_sizes is None: @@ -611,7 +612,7 @@ def forward( grids = grid_sizes split_sizes = [ - self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 3), + self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), self.attention_head_dim // 6, self.attention_head_dim // 6, ] @@ -620,7 +621,9 @@ def forward( # Loop over samples output = torch.view_as_complex( - torch.zeros((batch_size, S, num_heads, -1, 2), device=hidden_states.device).to(torch.float64) + torch.zeros((batch_size, S, num_heads, self.attention_head_dim // 2, 2), device=hidden_states.device).to( + torch.float64 + ) ) seq_bucket = [0] for g in grids: @@ -1093,9 +1096,9 @@ def forward( hidden_states = hidden_states.flatten(2).transpose(1, 2) image_latents = image_latents.flatten(2).transpose(1, 2) - sequence_length = hidden_states.shape[1].to(torch.long) + sequence_length = torch.tensor([hidden_states.shape[1]], dtype=torch.long) original_sequence_length = sequence_length - sequence_length = sequence_length + image_latents.shape[1].to(torch.long) + sequence_length = sequence_length + torch.tensor([image_latents.shape[1]], dtype=torch.long) hidden_states = torch.cat([hidden_states, image_latents], dim=1) # Initialize masks to indicate noisy latent, image latent, and motion latent. 
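Editor's note: the transformer hunk above keeps two lengths around — `original_sequence_length` for the flattened noisy-video tokens alone, and `sequence_length` after the reference-image tokens are concatenated onto the sequence — so the extra tokens can be attended to by the blocks and sliced off again with `hidden_states[:, :original_sequence_length]` before the output projection, as other hunks in this series show. A minimal sketch of that bookkeeping, with made-up shapes:

```python
# Editorial sketch only; shapes are illustrative, not the model's.
import torch

batch, dim = 1, 8
video_tokens = torch.randn(batch, 6, dim)   # flattened noisy video latents
image_tokens = torch.randn(batch, 2, dim)   # flattened reference-image latents

original_sequence_length = video_tokens.shape[1]
hidden_states = torch.cat([video_tokens, image_tokens], dim=1)
sequence_length = hidden_states.shape[1]    # 6 + 2 = 8 tokens seen by the attention blocks

# ... transformer blocks would attend over the full 8-token sequence here ...

# Only the video part is kept for the final norm / unpatchify step.
hidden_states = hidden_states[:, :original_sequence_length]
print(sequence_length, hidden_states.shape)  # 8 torch.Size([1, 6, 8])
```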
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index a82aa1544d4e..c357a427c261 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -568,9 +568,7 @@ def prepare_latents( latent_condition = latent_condition.to(dtype) latent_condition = (latent_condition - latents_mean) * latents_std - motion_pixels = torch.zeros( - [1, 3, self.motion_frames, height, width], dtype=transformer_dtype, device=device - ) + motion_pixels = torch.zeros([1, 3, self.motion_frames, height, width], dtype=self.vae.dtype, device=device) # Get pose condition input if needed pose_condition = self.load_pose_condition( pose_video, num_chunks, num_frames_per_chunk, (height, width), sampling_fps @@ -579,7 +577,7 @@ def prepare_latents( if init_first_frame: self.drop_first_motion = False motion_pixels[:, :, -6:] = latent_condition - motion_latents = torch.stack(self.vae.encode(motion_pixels)) + motion_latents = retrieve_latents(self.vae.encode(motion_pixels), sample_mode="argmax") videos_last_latents = motion_latents.detach() return latents, latent_condition, videos_last_latents, motion_latents, pose_condition @@ -610,8 +608,8 @@ def load_pose_condition(self, pose_video, num_chunks, num_frames_per_chunk, size pose_condition = [] for cond in cond_tensors: cond = torch.cat([cond[:, :, 0:1], cond], dim=2) - cond = cond.to(dtype=self.config.dtype, device=self._execution_device) - cond_lat = self.vae.encode(cond)[:, :, 1:] + cond = cond.to(dtype=self.dtype, device=self._execution_device) + cond_lat = retrieve_latents(self.vae.encode(cond), sample_mode="argmax")[:, :, 1:] pose_condition.append(cond_lat) return pose_condition @@ -866,7 +864,6 @@ def __call__( ) if num_chunks is None or num_chunks > num_chunks_audio: num_chunks = num_chunks_audio - #audio_embeds = audio_embeds.repeat(batch_size, 1, 1) audio_embeds = audio_embeds.to(transformer_dtype) latent_motion_frames = (self.motion_frames + 3) // 4 @@ -915,7 +912,7 @@ def __call__( pose_latents = pose_condition[r] if pose_video else pose_condition[0] pose_latents = pose_latents.to(dtype=transformer_dtype, device=device) audio_embeds_input = audio_embeds[..., left_idx:right_idx] - motion_latents_input = motion_latents.clone() + motion_latents_input = motion_latents.to(transformer_dtype).clone() with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -925,6 +922,7 @@ def __call__( self._current_timestep = t latent_model_input = latents.to(transformer_dtype) + condition = condition.to(transformer_dtype) timestep = t.expand(latents.shape[0]) with self.transformer.cache_context("cond"): diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 414be41857fb..6c4290bf1d29 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -122,7 +122,7 @@ is_xformers_version, requires_backends, ) -from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video, load_audio +from .loading_utils import get_module_from_name, get_submodule_by_name, load_audio, load_image, load_video from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( From 83412186a051ba70381563cec959d3fde8401161 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 5 Sep 2025 20:05:35 +0300 Subject: [PATCH 051/131] up --- docs/source/en/api/pipelines/wan.md | 2 +- src/diffusers/models/transformers/transformer_wan_s2v.py 
| 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 9ece588ccba1..d61292d9db8a 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -258,7 +258,7 @@ from transformers import Wav2Vec2ForCTC model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" -audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", torch_dtype=torch.float32) +audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32) vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) pipe = WanSpeechToVideoPipeline.from_pretrained( model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16 diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 63e0a36f1963..f054cde0a503 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -968,7 +968,7 @@ def inject_motion( mot, mot_remb = self.process_motion(motion_latents, drop_motion_frames) if len(mot) > 0: - hidden_states = [torch.cat([u, m], dim=1) for u, m in zip(hidden_states, mot)] + hidden_states = [torch.cat([u.unsqueeze(0), m], dim=1) for u, m in zip(hidden_states, mot)] seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long) rope_embs = [torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb)] mask_input = [ From 6663e584bf9d29431f61ec52b01467a07d2bb709 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 5 Sep 2025 20:17:14 +0300 Subject: [PATCH 052/131] --- src/diffusers/models/transformers/transformer_wan_s2v.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index f054cde0a503..dfa0534f927e 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -970,9 +970,9 @@ def inject_motion( if len(mot) > 0: hidden_states = [torch.cat([u.unsqueeze(0), m], dim=1) for u, m in zip(hidden_states, mot)] seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long) - rope_embs = [torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb)] + rope_embs = [torch.cat([u.unsqueeze(0), m], dim=1) for u, m in zip(rope_embs, mot_remb)] mask_input = [ - torch.cat([m, 2 * torch.ones([1, u.shape[1] - m.shape[1]], device=m.device, dtype=m.dtype)], dim=1) + torch.cat([m.unsqueeze(0), 2 * torch.ones([1, u.shape[2] - m.shape[2]], device=m.device, dtype=m.dtype)], dim=1) for m, u in zip(mask_input, hidden_states) ] return hidden_states, seq_lens, rope_embs, mask_input From a9b08de2d4e48366519449bdef75532698642e2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 6 Sep 2025 16:55:32 +0300 Subject: [PATCH 053/131] Fix bugs and improve stability in WanSpeechToVideo model This commit addresses several issues in the WanSpeechToVideo pipeline and transformer model to improve robustness and correct behavior. Key changes include: - Fixing rotary position embedding logic by correctly passing and using the number of attention heads. - Ensuring data type consistency across various tensor operations, particularly for rotary embeddings and audio embeddings. 
- Correcting the normalization of pose and motion latents. - Simplifying and fixing tensor shape manipulations during forward passes. - Removing an unused parameter from the `prepare_latents` method. --- .../transformers/transformer_wan_s2v.py | 49 +++++++++++-------- .../pipelines/wan/pipeline_wan_s2v.py | 17 +++---- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index dfa0534f927e..08758b210902 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -122,19 +122,19 @@ def __call__( def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): # dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64 - n = query.size(2) + n = hidden_states.size(2) # loop over samples output = [] - for i in range(query.size(0)): - s = query.size(1) - x_i = torch.view_as_complex(query[i, :s].to(torch.float64).reshape(s, n, -1, 2)) + for i in range(hidden_states.size(0)): + s = hidden_states.size(1) + x_i = torch.view_as_complex(hidden_states[i, :s].to(torch.float64).reshape(s, n, -1, 2)) freqs_i = freqs[i, :s] # apply rotary embedding x_i = torch.view_as_real(x_i * freqs_i).flatten(2) x_i = torch.cat([x_i, hidden_states[i, s:]]) # append to collection output.append(x_i) - return torch.stack(output).float() + return torch.stack(output).type_as(hidden_states) query = apply_rotary_emb(query, rotary_emb) key = apply_rotary_emb(key, rotary_emb) @@ -407,7 +407,12 @@ def __init__( self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long) - self.rope = WanS2VRotaryPosEmbed(inner_dim // num_attention_heads, patch_size=patch_size, max_seq_len=1024) + self.rope = WanS2VRotaryPosEmbed( + inner_dim // num_attention_heads, + patch_size=patch_size, + max_seq_len=1024, + num_attention_heads=num_attention_heads, + ) def forward(self, motion_latents, add_last_motion=2): mot = [] @@ -555,6 +560,7 @@ def __init__( attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, + num_attention_heads: int, theta: float = 10000.0, ): super().__init__() @@ -562,6 +568,7 @@ def __init__( self.attention_head_dim = attention_head_dim self.patch_size = patch_size self.max_seq_len = max_seq_len + self.num_attention_heads = num_attention_heads h_dim = w_dim = 2 * (attention_head_dim // 6) t_dim = attention_head_dim - h_dim - w_dim @@ -606,9 +613,8 @@ def forward( grids = [grid_sizes, image_grid_sizes] S = ppf * pph * ppw + image_latents.shape[3] // p_h * image_latents.shape[4] // p_w - num_heads = num_channels // self.attention_head_dim else: # FramePack's RoPE - batch_size, S, num_heads, _ = hidden_states.shape + batch_size, S, _, _ = hidden_states.shape grids = grid_sizes split_sizes = [ @@ -621,9 +627,11 @@ def forward( # Loop over samples output = torch.view_as_complex( - torch.zeros((batch_size, S, num_heads, self.attention_head_dim // 2, 2), device=hidden_states.device).to( - torch.float64 - ) + torch.zeros( + (batch_size, S, self.num_attention_heads, + self.attention_head_dim // 2, + 2), device=hidden_states.device + ).to(torch.float64) ) seq_bucket = [0] for g in grids: @@ -854,7 +862,7 @@ def __init__( out_channels = out_channels or in_channels # 1. 
Patch & position embedding - self.rope = WanS2VRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.rope = WanS2VRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len, num_attention_heads) self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) if enable_framepack: @@ -972,7 +980,10 @@ def inject_motion( seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long) rope_embs = [torch.cat([u.unsqueeze(0), m], dim=1) for u, m in zip(rope_embs, mot_remb)] mask_input = [ - torch.cat([m.unsqueeze(0), 2 * torch.ones([1, u.shape[2] - m.shape[2]], device=m.device, dtype=m.dtype)], dim=1) + torch.cat( + [m, 2 * torch.ones([1, u.shape[1] - m.shape[1]], device=m.device, dtype=m.dtype)], + dim=1, + ) for m, u in zip(mask_input, hidden_states) ] return hidden_states, seq_lens, rope_embs, mask_input @@ -1000,12 +1011,8 @@ def after_transformer_block( attn_hidden_states = self.audio_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states) residual_out = self.audio_injector.injector[audio_attn_id]( - x=attn_hidden_states, - context=attn_audio_emb, - context_lens=torch.ones( - attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device - ) - * attn_audio_emb.shape[1], + attn_hidden_states, + attn_audio_emb, ) residual_out = residual_out.unflatten(0, (-1, merged_audio_emb_num_frames)).flatten(1, 2) hidden_states[:, :original_sequence_length] = hidden_states[:, :original_sequence_length] + residual_out @@ -1141,8 +1148,8 @@ def forward( timestep_proj = [timestep_proj, 0] merged_audio_emb_num_frames = merged_audio_emb.shape[1] # B F N C - attn_audio_emb = merged_audio_emb.flatten(0, 1) - audio_emb_global = audio_emb_global.flatten(0, 1) + attn_audio_emb = merged_audio_emb.flatten(0, 1).to(hidden_states.dtype) + audio_emb_global = audio_emb_global.flatten(0, 1).to(hidden_states.dtype) # 5. 
Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index c357a427c261..c59778eb22e6 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -517,7 +517,6 @@ def prepare_latents( init_first_frame: bool = False, num_chunks: int = 1, sampling_fps: int = 16, - transformer_dtype: torch.dtype = torch.bfloat16, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[torch.Tensor]]]: num_latent_frames = ( num_frames_per_chunk + 3 + self.motion_frames @@ -540,12 +539,7 @@ def prepare_latents( if image is not None: image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] - video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames_per_chunk - 1, height, width)], - dim=2, - ) - - video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + video_condition = image.to(device=device, dtype=self.vae.dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) @@ -571,20 +565,21 @@ def prepare_latents( motion_pixels = torch.zeros([1, 3, self.motion_frames, height, width], dtype=self.vae.dtype, device=device) # Get pose condition input if needed pose_condition = self.load_pose_condition( - pose_video, num_chunks, num_frames_per_chunk, (height, width), sampling_fps + pose_video, num_chunks, num_frames_per_chunk, (height, width), sampling_fps, latents_mean, latents_std ) # Encode motion latents if init_first_frame: self.drop_first_motion = False motion_pixels[:, :, -6:] = latent_condition motion_latents = retrieve_latents(self.vae.encode(motion_pixels), sample_mode="argmax") + motion_latents = (motion_latents - latents_mean) * latents_std videos_last_latents = motion_latents.detach() return latents, latent_condition, videos_last_latents, motion_latents, pose_condition else: return latents - def load_pose_condition(self, pose_video, num_chunks, num_frames_per_chunk, size, sampling_fps): + def load_pose_condition(self, pose_video, num_chunks, num_frames_per_chunk, size, sampling_fps, latents_mean, latents_std): HEIGHT, WIDTH = size if pose_video is not None: pose_seq = self.read_last_n_frames( @@ -610,6 +605,7 @@ def load_pose_condition(self, pose_video, num_chunks, num_frames_per_chunk, size cond = torch.cat([cond[:, :, 0:1], cond], dim=2) cond = cond.to(dtype=self.dtype, device=self._execution_device) cond_lat = retrieve_latents(self.vae.encode(cond), sample_mode="argmax")[:, :, 1:] + cond_lat = (cond_lat - latents_mean) * latents_std pose_condition.append(cond_lat) return pose_condition @@ -685,7 +681,7 @@ def __call__( height: int = 480, width: int = 832, num_frames_per_chunk: int = 81, - num_inference_steps: int = 50, + num_inference_steps: int = 40, guidance_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -898,7 +894,6 @@ def __call__( init_first_frame, num_chunks, sampling_fps, - transformer_dtype, ) if r == 0: From 86123d92c54cf88ea616b32b601d33f9b871dfb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 6 Sep 2025 16:55:42 +0300 Subject: [PATCH 054/131] style --- src/diffusers/models/transformers/transformer_wan_s2v.py | 6 ++---- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 4 +++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git 
a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 08758b210902..1419e0a16860 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -628,10 +628,8 @@ def forward( # Loop over samples output = torch.view_as_complex( torch.zeros( - (batch_size, S, self.num_attention_heads, - self.attention_head_dim // 2, - 2), device=hidden_states.device - ).to(torch.float64) + (batch_size, S, self.num_attention_heads, self.attention_head_dim // 2, 2), device=hidden_states.device + ).to(torch.float64) ) seq_bucket = [0] for g in grids: diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index c59778eb22e6..2834f4e85741 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -579,7 +579,9 @@ def prepare_latents( else: return latents - def load_pose_condition(self, pose_video, num_chunks, num_frames_per_chunk, size, sampling_fps, latents_mean, latents_std): + def load_pose_condition( + self, pose_video, num_chunks, num_frames_per_chunk, size, sampling_fps, latents_mean, latents_std + ): HEIGHT, WIDTH = size if pose_video is not None: pose_seq = self.read_last_n_frames( From 8064c42fd5f9dcf84783fac820b0d4b272aab671 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 6 Sep 2025 17:57:25 +0300 Subject: [PATCH 055/131] Enhance load_audio function to support audio loading from URLs using librosa and handle numpy arrays with a default sample rate. --- src/diffusers/utils/loading_utils.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index e7d75d77362a..781dab3afabc 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -161,7 +161,15 @@ def load_audio( """ if isinstance(audio, str): if audio.startswith("http://") or audio.startswith("https://"): - audio = PIL.Image.open(requests.get(audio, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw) + # Download audio from URL and load with librosa + response = requests.get(audio, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT) + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + for chunk in response.iter_content(chunk_size=8192): + temp_file.write(chunk) + temp_audio_path = temp_file.name + + audio, sample_rate = librosa.load(temp_audio_path, sr=16000) + os.remove(temp_audio_path) # Clean up temporary file elif os.path.isfile(audio): audio, sample_rate = librosa.load(audio, sr=16000) else: @@ -170,9 +178,10 @@ def load_audio( ) elif isinstance(audio, numpy.ndarray): audio = audio + sample_rate = 16000 # Default sample rate for numpy arrays else: raise ValueError( - "Incorrect format used for the audio. Should be a URL linking to an audio, a local path, or a PIL audio." + "Incorrect format used for the audio. Should be a URL linking to an audio, a local path, or a numpy array." 
) return audio, sample_rate From ac16d5ddc6941804b905114b9cf01dc92729dd32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 7 Sep 2025 09:21:16 +0300 Subject: [PATCH 056/131] upp --- docs/source/en/api/pipelines/wan.md | 2 +- .../models/transformers/transformer_wan_s2v.py | 13 +++++-------- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 2 +- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index d61292d9db8a..afb46c02f19f 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -283,7 +283,7 @@ prompt = "CG animation style, a small blue bird takes off from the ground, flapp output = pipe( image=first_frame, audio=audio, sampling_rate=sampling_rate, - prompt=prompt, height=height, width=width, guidance_scale=5.0, num_frames_per_chunk=81, + prompt=prompt, height=height, width=width, num_frames_per_chunk=81, #pose_video=pose_video ).frames[0] export_to_video(output, "output.mp4", fps=16) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 1419e0a16860..402c810f9325 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -507,9 +507,9 @@ def __init__( time_proj_dim: int, text_embed_dim: int, audio_embed_dim: int, - enable_adain: bool = True, - pose_embed_dim: Optional[int] = None, - patch_size: Optional[Tuple[int]] = None, + pose_embed_dim: int, + patch_size: Tuple[int], + enable_adain: bool, ): super().__init__() @@ -522,9 +522,7 @@ def __init__( dim=audio_embed_dim, out_dim=dim, num_audio_token=4, need_global=enable_adain ) - self.pose_embedder = None - if pose_embed_dim is not None: - self.pose_embedder = nn.Conv3d(pose_embed_dim, dim, kernel_size=patch_size, stride=patch_size) + self.pose_embedder = nn.Conv3d(pose_embed_dim, dim, kernel_size=patch_size, stride=patch_size) def forward( self, @@ -548,8 +546,7 @@ def forward( audio_hidden_states = self.causal_audio_encoder(audio_hidden_states) - if self.pose_embedder is not None: - pose_hidden_states = self.pose_embedder(pose_hidden_states) + pose_hidden_states = self.pose_embedder(pose_hidden_states) return temb, timestep_proj, encoder_hidden_states, audio_hidden_states, pose_hidden_states diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 2834f4e85741..47c10416ce83 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -684,7 +684,7 @@ def __call__( width: int = 832, num_frames_per_chunk: int = 81, num_inference_steps: int = 40, - guidance_scale: float = 5.0, + guidance_scale: float = 4.5, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, From dbc07645a7fe2caaad18b40cf03d302f791ff827 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 7 Sep 2025 09:23:06 +0300 Subject: [PATCH 057/131] add _repeated_blocks --- src/diffusers/models/transformers/transformer_wan_s2v.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 402c810f9325..35b7b40149ff 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ 
b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -823,6 +823,7 @@ class WanS2VTransformer3DModel( _no_split_modules = ["WanS2VTransformerBlock"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3", "causal_audio_encoder"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + _repeated_blocks = ["WanS2VTransformerBlock"] @register_to_config def __init__( From 80a2fbe94226b1f1b63fd8285a2c7fcdeab8ee84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 7 Sep 2025 11:21:18 +0300 Subject: [PATCH 058/131] up --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 47c10416ce83..26a4a42f4e85 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -866,10 +866,6 @@ def __call__( latent_motion_frames = (self.motion_frames + 3) // 4 - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - # 5. Prepare latent variables num_channels_latents = self.vae.config.z_dim image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) @@ -911,6 +907,12 @@ def __call__( audio_embeds_input = audio_embeds[..., left_idx:right_idx] motion_latents_input = motion_latents.to(transformer_dtype).clone() + # 4. Prepare timesteps + self.scheduler = UniPCMultistepScheduler.from_pretrained("tolgacangoz/Wan2.1-T2V-14B-Diffusers", + subfolder="scheduler") + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: From acc8ecb0804b8d4f48197ad202301eeb879d9403 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 7 Sep 2025 11:23:16 +0300 Subject: [PATCH 059/131] fix --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 26a4a42f4e85..954e1fc09b97 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -870,10 +870,6 @@ def __call__( num_channels_latents = self.vae.config.z_dim image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) - # 6. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) - all_latents = [] for r in range(num_chunks): latents_outputs = self.prepare_latents( @@ -912,6 +908,9 @@ def __call__( subfolder="scheduler") self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): From a72903329d246bde7255b9b863a8eaa55c88fde7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 7 Sep 2025 11:24:52 +0300 Subject: [PATCH 060/131] up --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 954e1fc09b97..1fcd37539a23 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -904,7 +904,7 @@ def __call__( motion_latents_input = motion_latents.to(transformer_dtype).clone() # 4. Prepare timesteps - self.scheduler = UniPCMultistepScheduler.from_pretrained("tolgacangoz/Wan2.1-T2V-14B-Diffusers", + self.scheduler = UniPCMultistepScheduler.from_pretrained("tolgacangoz/Wan2.2-S2V-14B-Diffusers", subfolder="scheduler") self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps From a0d521793950274bcd666e7084eaa4f1e0dedd27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 7 Sep 2025 12:24:27 +0300 Subject: [PATCH 061/131] fix previous latensts --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 1fcd37539a23..43a1ae6da331 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -883,7 +883,7 @@ def __call__( torch.float32, device, generator, - latents, + latents if r == 0 else None, pose_video, init_first_frame, num_chunks, From 4fd1014fc46e8d9927d122364c8804614e3b1f39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 7 Sep 2025 19:25:41 +0300 Subject: [PATCH 062/131] set deterministic for fa2 --- src/diffusers/models/transformers/transformer_wan_s2v.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 35b7b40149ff..d062f9871ed5 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -139,6 +139,8 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): query = apply_rotary_emb(query, rotary_emb) key = apply_rotary_emb(key, rotary_emb) + attention_kwargs = {"deterministic": True} + # I2V task hidden_states_img = None if encoder_hidden_states_img is not None: @@ -156,6 +158,7 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): dropout_p=0.0, is_causal=False, backend=self._attention_backend, + attention_kwargs=attention_kwargs, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) @@ -168,6 +171,7 @@ def apply_rotary_emb(hidden_states: 
torch.Tensor, freqs: torch.Tensor): dropout_p=0.0, is_causal=False, backend=self._attention_backend, + attention_kwargs=attention_kwargs, ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) From 3773de3061ab0d64b2116060ef9ce9d355370839 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 7 Sep 2025 21:24:28 +0300 Subject: [PATCH 063/131] up --- src/diffusers/models/transformers/transformer_wan_s2v.py | 4 ---- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 7 ++++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index d062f9871ed5..35b7b40149ff 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -139,8 +139,6 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): query = apply_rotary_emb(query, rotary_emb) key = apply_rotary_emb(key, rotary_emb) - attention_kwargs = {"deterministic": True} - # I2V task hidden_states_img = None if encoder_hidden_states_img is not None: @@ -158,7 +156,6 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): dropout_p=0.0, is_causal=False, backend=self._attention_backend, - attention_kwargs=attention_kwargs, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) @@ -171,7 +168,6 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): dropout_p=0.0, is_causal=False, backend=self._attention_backend, - attention_kwargs=attention_kwargs, ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 43a1ae6da331..4f5c3dd1b597 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -904,11 +904,12 @@ def __call__( motion_latents_input = motion_latents.to(transformer_dtype).clone() # 4. 
Prepare timesteps - self.scheduler = UniPCMultistepScheduler.from_pretrained("tolgacangoz/Wan2.2-S2V-14B-Diffusers", - subfolder="scheduler") + self.scheduler = UniPCMultistepScheduler.from_pretrained( + "tolgacangoz/Wan2.2-S2V-14B-Diffusers", subfolder="scheduler" + ) self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) From d0e3e268aa933fa099f957f1899ccafe0d70c2fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 7 Sep 2025 21:38:04 +0300 Subject: [PATCH 064/131] update example docstring --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 4f5c3dd1b597..efc7a318b483 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -57,11 +57,15 @@ >>> import numpy as np >>> from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline >>> from diffusers.utils import export_to_video, load_image, load_audio, load_video + >>> from transformers import Wav2Vec2ForCTC >>> # Available models: Wan-AI/Wan2.2-S2V-14B-Diffusers >>> model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) - >>> pipe = WanSpeechToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32) + >>> pipe = WanSpeechToVideoPipeline.from_pretrained( + ... model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16 + ... ) >>> pipe.to("cuda") >>> first_frame = load_image( @@ -85,7 +89,6 @@ ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." ... ) >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" - >>> audio = load_audio(...) >>> output = pipe( ... prompt=prompt, @@ -97,7 +100,6 @@ ... height=height, ... width=width, ... num_frames_per_chunk=81, - ... guidance_scale=5.0, ... 
).frames[0] >>> export_to_video(output, "output.mp4", fps=16) ``` From fe41eddfef40c15894e2c67b19889a466f9b4548 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 08:18:55 +0300 Subject: [PATCH 065/131] upp --- docs/source/en/api/pipelines/wan.md | 79 ++++++++++++++++--- .../transformers/transformer_wan_s2v.py | 4 + .../pipelines/wan/pipeline_wan_s2v.py | 15 ++-- 3 files changed, 77 insertions(+), 21 deletions(-) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index afb46c02f19f..495be3ff880d 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -250,7 +250,7 @@ The example below demonstrates how to use the speech-to-video pipeline to genera ```python -import numpy as np +import numpy as np, math import torch from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline from diffusers.utils import export_to_video, load_image, load_audio, load_video @@ -267,24 +267,81 @@ pipe.to("cuda") first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png") audio, sampling_rate = load_audio("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png") -pose_video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_pose_video.mp4") - -def aspect_ratio_resize(image, pipe, max_area=720 * 1280): - aspect_ratio = image.height / image.width - mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] - height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value - width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value +#pose_path = E.g., download from "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4" + +def get_size_less_than_area(height, + width, + target_area=1024 * 704, + divisor=64): + if height * width <= target_area: + # If the original image area is already less than or equal to the target, + # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target. 
+ max_upper_area = target_area + min_scale = 0.1 + max_scale = 1.0 + else: + # Resize to fit within the target area and then pad to multiples of `divisor` + max_upper_area = target_area # Maximum allowed total pixel count after padding + d = divisor - 1 + b = d * (height + width) + a = height * width + c = d**2 - max_upper_area + + # Calculate scale boundaries using quadratic equation + min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (2 * a) # Scale when maximum padding is applied + max_scale = math.sqrt(max_upper_area / (height * width)) # Scale without any padding + + # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area + # Use binary search-like iteration to find this scale + find_it = False + for i in range(100): + scale = max_scale - (max_scale - min_scale) * i / 100 + new_height, new_width = int(height * scale), int(width * scale) + + # Pad to make dimensions divisible by 64 + pad_height = (64 - new_height % 64) % 64 + pad_width = (64 - new_width % 64) % 64 + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top + pad_left = pad_width // 2 + pad_right = pad_width - pad_left + + padded_height, padded_width = new_height + pad_height, new_width + pad_width + + if padded_height * padded_width <= max_upper_area: + find_it = True + break + + if find_it: + return padded_height, padded_width + else: + # Fallback: calculate target dimensions based on aspect ratio and divisor alignment + aspect_ratio = width / height + target_width = int( + (target_area * aspect_ratio)**0.5 // divisor * divisor) + target_height = int( + (target_area / aspect_ratio)**0.5 // divisor * divisor) + + # Ensure the result is not larger than the original resolution + if target_width >= width or target_height >= height: + target_width = int(width // divisor * divisor) + target_height = int(height // divisor * divisor) + + return target_height, target_width + +def aspect_ratio_resize(image, pipe, max_area): + height, width = get_size_less_than_area(image.size[1], image.size[0], target_area=max_area) image = image.resize((width, height)) return image, height, width -first_frame, height, width = aspect_ratio_resize(first_frame, pipe) +first_frame, height, width = aspect_ratio_resize(first_frame, pipe, 480*832) -prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." +prompt = "Einstein singing a song." 
output = pipe( image=first_frame, audio=audio, sampling_rate=sampling_rate, prompt=prompt, height=height, width=width, num_frames_per_chunk=81, - #pose_video=pose_video + #pose_video=pose_path ).frames[0] export_to_video(output, "output.mp4", fps=16) ``` diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 35b7b40149ff..d062f9871ed5 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -139,6 +139,8 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): query = apply_rotary_emb(query, rotary_emb) key = apply_rotary_emb(key, rotary_emb) + attention_kwargs = {"deterministic": True} + # I2V task hidden_states_img = None if encoder_hidden_states_img is not None: @@ -156,6 +158,7 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): dropout_p=0.0, is_causal=False, backend=self._attention_backend, + attention_kwargs=attention_kwargs, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) @@ -168,6 +171,7 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): dropout_p=0.0, is_causal=False, backend=self._attention_backend, + attention_kwargs=attention_kwargs, ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index efc7a318b483..eda4846735c4 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -57,15 +57,11 @@ >>> import numpy as np >>> from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline >>> from diffusers.utils import export_to_video, load_image, load_audio, load_video - >>> from transformers import Wav2Vec2ForCTC >>> # Available models: Wan-AI/Wan2.2-S2V-14B-Diffusers >>> model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) - >>> audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32) - >>> pipe = WanSpeechToVideoPipeline.from_pretrained( - ... model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16 - ... ) + >>> pipe = WanSpeechToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> first_frame = load_image( @@ -89,6 +85,7 @@ ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." ... ) >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + >>> audio = load_audio(...) >>> output = pipe( ... prompt=prompt, @@ -100,6 +97,7 @@ ... height=height, ... width=width, ... num_frames_per_chunk=81, + ... guidance_scale=5.0, ... ).frames[0] >>> export_to_video(output, "output.mp4", fps=16) ``` @@ -905,13 +903,10 @@ def __call__( audio_embeds_input = audio_embeds[..., left_idx:right_idx] motion_latents_input = motion_latents.to(transformer_dtype).clone() - # 4. 
Prepare timesteps - self.scheduler = UniPCMultistepScheduler.from_pretrained( - "tolgacangoz/Wan2.2-S2V-14B-Diffusers", subfolder="scheduler" - ) + # 4. Prepare timesteps by resetting scheduler in each chunk self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) From f5439e190e91f047ad62b4d8ce16a4a3986ab34b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 08:21:49 +0300 Subject: [PATCH 066/131] up --- src/diffusers/models/transformers/transformer_wan_s2v.py | 4 ---- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 8 +++++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index d062f9871ed5..35b7b40149ff 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -139,8 +139,6 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): query = apply_rotary_emb(query, rotary_emb) key = apply_rotary_emb(key, rotary_emb) - attention_kwargs = {"deterministic": True} - # I2V task hidden_states_img = None if encoder_hidden_states_img is not None: @@ -158,7 +156,6 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): dropout_p=0.0, is_causal=False, backend=self._attention_backend, - attention_kwargs=attention_kwargs, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) @@ -171,7 +168,6 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): dropout_p=0.0, is_causal=False, backend=self._attention_backend, - attention_kwargs=attention_kwargs, ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index eda4846735c4..e94b55e40cef 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -57,11 +57,15 @@ >>> import numpy as np >>> from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline >>> from diffusers.utils import export_to_video, load_image, load_audio, load_video + >>> from transformers import Wav2Vec2ForCTC >>> # Available models: Wan-AI/Wan2.2-S2V-14B-Diffusers >>> model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) - >>> pipe = WanSpeechToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32) + >>> pipe = WanSpeechToVideoPipeline.from_pretrained( + ... model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16 + ... ) >>> pipe.to("cuda") >>> first_frame = load_image( @@ -85,7 +89,6 @@ ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." ... 
) >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" - >>> audio = load_audio(...) >>> output = pipe( ... prompt=prompt, @@ -97,7 +100,6 @@ ... height=height, ... width=width, ... num_frames_per_chunk=81, - ... guidance_scale=5.0, ... ).frames[0] >>> export_to_video(output, "output.mp4", fps=16) ``` From d72f54989819380082b37f5d99cd3a399095a652 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 09:16:10 +0300 Subject: [PATCH 067/131] style --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index e94b55e40cef..dc944d14b40d 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -908,7 +908,7 @@ def __call__( # 4. Prepare timesteps by resetting scheduler in each chunk self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) From 5eea0c766e5f46f7eaa8d56fafd2745ca9c4291d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 14:24:28 +0300 Subject: [PATCH 068/131] Enhance load_video function with frame sampling options and reverse playback --- src/diffusers/utils/loading_utils.py | 47 ++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index 781dab3afabc..e8a89010b944 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -59,6 +59,9 @@ def load_image( def load_video( video: str, convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, + n_frames: Optional[int] = None, + target_fps: Optional[int] = None, + reverse: bool = False, ) -> List[PIL.Image.Image]: """ Loads `video` to a list of PIL Image. @@ -69,6 +72,13 @@ def load_video( convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): A conversion method to apply to the video after loading it. When set to `None` the images will be converted to "RGB". + n_frames (`int`, *optional*): + Number of frames to sample from the video. If None, all frames are loaded. + target_fps (`int`, *optional*): + Target sampling frame rate. If None, uses original frame rate. + reverse (`bool`, *optional*): + If True, samples frames starting from the beginning of the video; if False, samples frames starting from the end. + Defaults to False. 
Returns: `List[PIL.Image.Image]`: @@ -127,9 +137,40 @@ def load_video( ) with imageio.get_reader(video) as reader: - # Read all frames - for frame in reader: - pil_images.append(PIL.Image.fromarray(frame)) + # Determine which frames to sample + if n_frames is not None and target_fps is not None: + # Get video metadata + total_frames = reader.count_frames() + original_fps = reader.get_meta_data().get('fps') + + # Calculate sampling interval based on target fps + interval = max(1, round(original_fps / target_fps)) + required_span = (n_frames - 1) * interval + + if reverse: + start_frame = 0 + else: + start_frame = max(0, total_frames - required_span - 1) + + # Generate sampling indices + sampled_indices = [] + for i in range(n_frames): + indice = start_frame + i * interval + if indice >= total_frames: + break + sampled_indices.append(int(indice)) + + # Read specific frames + for idx in sampled_indices: + try: + frame = reader.get_data(idx) + pil_images.append(PIL.Image.fromarray(frame)) + except IndexError: + break + else: + # Read all frames + for frame in reader: + pil_images.append(PIL.Image.fromarray(frame)) if was_tempfile_created: os.remove(video_path) From 33e5b67604bc4d7977e9e4b316a7a59682c38a8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 14:44:45 +0300 Subject: [PATCH 069/131] Refactor load_pose_condition method to simplify pose video handling and remove unused read_last_n_frames method --- .../pipelines/wan/pipeline_wan_s2v.py | 66 ++++--------------- 1 file changed, 12 insertions(+), 54 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index dc944d14b40d..46e3778e2d0d 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -581,31 +581,18 @@ def prepare_latents( else: return latents - def load_pose_condition( - self, pose_video, num_chunks, num_frames_per_chunk, size, sampling_fps, latents_mean, latents_std - ): + def load_pose_condition(self, pose_video, num_chunks, num_frames_per_chunk, size, latents_mean, latents_std): HEIGHT, WIDTH = size if pose_video is not None: - pose_seq = self.read_last_n_frames( - pose_video, n_frames=num_frames_per_chunk * num_chunks, target_fps=sampling_fps, reverse=True - ) - - resize_opreat = transforms.Resize(min(HEIGHT, WIDTH)) - crop_opreat = transforms.CenterCrop((HEIGHT, WIDTH)) - - cond_tensor = torch.from_numpy(pose_seq) - cond_tensor = cond_tensor.permute(0, 3, 1, 2) / 255.0 * 2 - 1.0 - cond_tensor = crop_opreat(resize_opreat(cond_tensor)).permute(1, 0, 2, 3).unsqueeze(0) - - padding_frame_num = num_chunks * num_frames_per_chunk - cond_tensor.shape[2] - cond_tensor = torch.cat([cond_tensor, -torch.ones([1, 3, padding_frame_num, HEIGHT, WIDTH])], dim=2) + padding_frame_num = num_chunks * num_frames_per_chunk - pose_video.shape[2] + pose_video = torch.cat([pose_video, -torch.ones([1, 3, padding_frame_num, HEIGHT, WIDTH])], dim=2) - cond_tensors = torch.chunk(cond_tensor, num_chunks, dim=2) + pose_video = torch.chunk(pose_video, num_chunks, dim=2) else: - cond_tensors = [-torch.ones([1, 3, num_frames_per_chunk, HEIGHT, WIDTH])] + pose_video = [-torch.ones([1, 3, num_frames_per_chunk, HEIGHT, WIDTH])] pose_condition = [] - for cond in cond_tensors: + for cond in pose_video: cond = torch.cat([cond[:, :, 0:1], cond], dim=2) cond = cond.to(dtype=self.dtype, device=self._execution_device) cond_lat = retrieve_latents(self.vae.encode(cond), sample_mode="argmax")[:, :, 1:] @@ 
-614,40 +601,6 @@ def load_pose_condition( return pose_condition - def read_last_n_frames(self, video_path, n_frames, target_fps=16, reverse=False): - """ - Read the last `n_frames` from a video at the specified frame rate. - - Parameters: - video_path (str): Path to the video file. - n_frames (int): Number of frames to read. - target_fps (int, optional): Target sampling frame rate. Defaults to 16. - reverse (bool, optional): Whether to read frames in reverse order. - If True, reads the first `n_frames` instead of the last ones. - - Returns: - np.ndarray: A NumPy array of shape [n_frames, H, W, 3], representing the sampled video frames. - """ - vr = VideoReader(video_path) - original_fps = vr.get_avg_fps() - total_frames = len(vr) - - interval = max(1, round(original_fps / target_fps)) - - required_span = (n_frames - 1) * interval - - start_frame = max(0, total_frames - required_span - 1) if not reverse else 0 - - sampled_indices = [] - for i in range(n_frames): - indice = start_frame + i * interval - if indice >= total_frames: - break - else: - sampled_indices.append(indice) - - return vr.get_batch(sampled_indices).asnumpy() - @property def guidance_scale(self): return self._guidance_scale @@ -872,6 +825,11 @@ def __call__( num_channels_latents = self.vae.config.z_dim image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if pose_video is not None: + pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + all_latents = [] for r in range(num_chunks): latents_outputs = self.prepare_latents( @@ -900,7 +858,7 @@ def __call__( with torch.no_grad(): left_idx = r * num_frames_per_chunk right_idx = r * num_frames_per_chunk + num_frames_per_chunk - pose_latents = pose_condition[r] if pose_video else pose_condition[0] + pose_latents = pose_condition[r] if pose_video else pose_condition[0] * 0 pose_latents = pose_latents.to(dtype=transformer_dtype, device=device) audio_embeds_input = audio_embeds[..., left_idx:right_idx] motion_latents_input = motion_latents.to(transformer_dtype).clone() From 9562c2605bc13bf1c829db48430fb1243fc8401a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 14:50:43 +0300 Subject: [PATCH 070/131] style --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 13 ++++++------- src/diffusers/utils/loading_utils.py | 8 ++++---- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 46e3778e2d0d..fac81a46e565 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -20,9 +20,7 @@ import regex as re import torch import torch.nn.functional as F -from decord import VideoReader from PIL import Image -from torchvision import transforms from transformers import AutoTokenizer, UMT5EncoderModel, Wav2Vec2ForCTC, Wav2Vec2Processor from ...audio_processor import PipelineAudioInput @@ -567,7 +565,7 @@ def prepare_latents( motion_pixels = torch.zeros([1, 3, self.motion_frames, height, width], dtype=self.vae.dtype, device=device) # Get pose condition input if needed pose_condition = self.load_pose_condition( - pose_video, num_chunks, num_frames_per_chunk, (height, width), sampling_fps, latents_mean, latents_std + pose_video, num_chunks, num_frames_per_chunk, height, width, latents_mean, latents_std ) # Encode motion latents if init_first_frame: @@ -581,15 +579,16 
@@ def prepare_latents( else: return latents - def load_pose_condition(self, pose_video, num_chunks, num_frames_per_chunk, size, latents_mean, latents_std): - HEIGHT, WIDTH = size + def load_pose_condition( + self, pose_video, num_chunks, num_frames_per_chunk, height, width, latents_mean, latents_std + ): if pose_video is not None: padding_frame_num = num_chunks * num_frames_per_chunk - pose_video.shape[2] - pose_video = torch.cat([pose_video, -torch.ones([1, 3, padding_frame_num, HEIGHT, WIDTH])], dim=2) + pose_video = torch.cat([pose_video, -torch.ones([1, 3, padding_frame_num, height, width])], dim=2) pose_video = torch.chunk(pose_video, num_chunks, dim=2) else: - pose_video = [-torch.ones([1, 3, num_frames_per_chunk, HEIGHT, WIDTH])] + pose_video = [-torch.ones([1, 3, num_frames_per_chunk, height, width])] pose_condition = [] for cond in pose_video: diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index e8a89010b944..29e8a7855fdd 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -77,8 +77,8 @@ def load_video( target_fps (`int`, *optional*): Target sampling frame rate. If None, uses original frame rate. reverse (`bool`, *optional*): - If True, samples frames starting from the beginning of the video; if False, samples frames starting from the end. - Defaults to False. + If True, samples frames starting from the beginning of the video; if False, samples frames starting from + the end. Defaults to False. Returns: `List[PIL.Image.Image]`: @@ -141,8 +141,8 @@ def load_video( if n_frames is not None and target_fps is not None: # Get video metadata total_frames = reader.count_frames() - original_fps = reader.get_meta_data().get('fps') - + original_fps = reader.get_meta_data().get("fps") + # Calculate sampling interval based on target fps interval = max(1, round(original_fps / target_fps)) required_span = (n_frames - 1) * interval From 8c4c0183f00b01a2ef97685717234f9bb4904921 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 16:36:14 +0300 Subject: [PATCH 071/131] Update parameter descriptions and simplify tensor operations in WanSpeechToVideoPipeline --- src/diffusers/models/transformers/transformer_wan_s2v.py | 6 +++--- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 35b7b40149ff..2c2acc273dce 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -186,9 +186,9 @@ class AdaLayerNorm(nn.Module): Parameters: embedding_dim (`int`): The size of each embedding vector. - output_dim (`int`, *optional*): - norm_elementwise_affine (`bool`, defaults to `False): - norm_eps (`bool`, defaults to `False`): + output_dim (`int`, *optional*): Output dimension for the layer. + norm_elementwise_affine (`bool`, defaults to `False`): Whether to use elementwise affine in LayerNorm. + norm_eps (`float`, defaults to `1e-5`): Epsilon value for LayerNorm. 
""" def __init__( diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index fac81a46e565..be71a74f6992 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -343,7 +343,7 @@ def encode_audio( batch_audio_eb.append(frame_audio_embed) audio_embed_bucket = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) - audio_embed_bucket = audio_embed_bucket.to(device, self.dtype) + audio_embed_bucket = audio_embed_bucket.to(device) audio_embed_bucket = audio_embed_bucket.unsqueeze(0) if len(audio_embed_bucket.shape) == 3: audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) @@ -857,7 +857,7 @@ def __call__( with torch.no_grad(): left_idx = r * num_frames_per_chunk right_idx = r * num_frames_per_chunk + num_frames_per_chunk - pose_latents = pose_condition[r] if pose_video else pose_condition[0] * 0 + pose_latents = pose_condition[r] if pose_video is not None else pose_condition[0] * 0 pose_latents = pose_latents.to(dtype=transformer_dtype, device=device) audio_embeds_input = audio_embeds[..., left_idx:right_idx] motion_latents_input = motion_latents.to(transformer_dtype).clone() From 12facf8ce58a45b9f23cd64b178424479ce0dbf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 17:54:54 +0300 Subject: [PATCH 072/131] Propose to vectorize to assume each element in a batch standard, same --- .../transformers/transformer_wan_s2v.py | 154 ++++++++---------- .../pipelines/wan/pipeline_wan_s2v.py | 16 +- 2 files changed, 80 insertions(+), 90 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 2c2acc273dce..cc1157f8fad8 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -415,88 +415,81 @@ def __init__( ) def forward(self, motion_latents, add_last_motion=2): - mot = [] - mot_remb = [] - for motion_latent in motion_latents: - latent_height, latent_width = motion_latent.shape[2], motion_latent.shape[3] - padd_latent = torch.zeros(16, self.zip_frame_buckets.sum(), latent_height, latent_width).to( - device=motion_latent.device, dtype=motion_latent.dtype - ) - overlap_frame = min(padd_latent.shape[1], motion_latent.shape[1]) - if overlap_frame > 0: - padd_latent[:, -overlap_frame:] = motion_latent[:, -overlap_frame:] - - if add_last_motion < 2 and self.drop_mode != "drop": - zero_end_frame = self.zip_frame_buckets[: len(self.zip_frame_buckets) - add_last_motion - 1].sum() - padd_latent[:, -zero_end_frame:] = 0 - - padd_latent = padd_latent.unsqueeze(0) - clean_latents_4x, clean_latents_2x, clean_latents_post = padd_latent[ - :, :, -self.zip_frame_buckets.sum() :, :, : - ].split(list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2, 1 - - # Patchify - clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2) - clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2) - clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2) - - if add_last_motion < 2 and self.drop_mode == "drop": - clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post - clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x - - motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) - - # RoPE - start_time_id = -(self.zip_frame_buckets[:1].sum()) - 
end_time_id = start_time_id + self.zip_frame_buckets[0] - grid_sizes = ( - [] - if add_last_motion < 2 and self.drop_mode == "drop" - else [ - [ - torch.tensor([start_time_id, 0, 0]).unsqueeze(0), - torch.tensor([end_time_id, latent_height // 2, latent_width // 2]).unsqueeze(0), - torch.tensor([self.zip_frame_buckets[0], latent_height // 2, latent_width // 2]).unsqueeze(0), - ] - ] - ) - - start_time_id = -(self.zip_frame_buckets[:2].sum()) - end_time_id = start_time_id + self.zip_frame_buckets[1] // 2 - grid_sizes_2x = ( - [] - if add_last_motion < 1 and self.drop_mode == "drop" - else [ - [ - torch.tensor([start_time_id, 0, 0]).unsqueeze(0), - torch.tensor([end_time_id, latent_height // 4, latent_width // 4]).unsqueeze(0), - torch.tensor([self.zip_frame_buckets[1], latent_height // 2, latent_width // 2]).unsqueeze(0), - ] + latent_height, latent_width = motion_latents.shape[3], motion_latents.shape[4] + padd_latent = torch.zeros(motion_latents.shape[0], 16, self.zip_frame_buckets.sum(), latent_height, latent_width).to( + device=motion_latents.device, dtype=motion_latents.dtype) + overlap_frame = min(padd_latent.shape[2], motion_latents.shape[2]) + if overlap_frame > 0: + padd_latent[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:] + + if add_last_motion < 2 and self.drop_mode != "drop": + zero_end_frame = self.zip_frame_buckets[: len(self.zip_frame_buckets) - add_last_motion - 1].sum() + padd_latent[:, :, -zero_end_frame:] = 0 + + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_latent[ + :, :, -self.zip_frame_buckets.sum() :, :, : + ].split(list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2, 1 + + # Patchify + clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2) + clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2) + clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2) + + if add_last_motion < 2 and self.drop_mode == "drop": + clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post + clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x + + motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) + + # RoPE + start_time_id = -(self.zip_frame_buckets[:1].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[0] + grid_sizes = ( + [] + if add_last_motion < 2 and self.drop_mode == "drop" + else [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0), + torch.tensor([end_time_id, latent_height // 2, latent_width // 2]).unsqueeze(0), + torch.tensor([self.zip_frame_buckets[0], latent_height // 2, latent_width // 2]).unsqueeze(0), ] - ) + ] + ) - start_time_id = -(self.zip_frame_buckets[:3].sum()) - end_time_id = start_time_id + self.zip_frame_buckets[2] // 4 - grid_sizes_4x = [ + start_time_id = -(self.zip_frame_buckets[:2].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[1] // 2 + grid_sizes_2x = ( + [] + if add_last_motion < 1 and self.drop_mode == "drop" + else [ [ torch.tensor([start_time_id, 0, 0]).unsqueeze(0), - torch.tensor([end_time_id, latent_height // 8, latent_width // 8]).unsqueeze(0), - torch.tensor([self.zip_frame_buckets[2], latent_height // 2, latent_width // 2]).unsqueeze(0), + torch.tensor([end_time_id, latent_height // 4, latent_width // 4]).unsqueeze(0), + torch.tensor([self.zip_frame_buckets[1], latent_height // 2, latent_width // 2]).unsqueeze(0), ] ] + ) + + start_time_id = -(self.zip_frame_buckets[:3].sum()) + end_time_id = start_time_id + 
self.zip_frame_buckets[2] // 4 + grid_sizes_4x = [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0), + torch.tensor([end_time_id, latent_height // 8, latent_width // 8]).unsqueeze(0), + torch.tensor([self.zip_frame_buckets[2], latent_height // 2, latent_width // 2]).unsqueeze(0), + ] + ] - grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x + grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x - motion_rope_emb = self.rope( - motion_lat.detach().view( - 1, motion_lat.shape[1], self.num_attention_heads, self.inner_dim // self.num_attention_heads - ), - grid_sizes=grid_sizes, - ) + motion_rope_emb = self.rope( + motion_lat.detach().view( + motion_lat.shape[0], motion_lat.shape[1], self.num_attention_heads, self.inner_dim // self.num_attention_heads + ), + grid_sizes=grid_sizes, + ) - mot.append(motion_lat) - mot_remb.append(motion_rope_emb) - return mot, mot_remb + return motion_lat, motion_rope_emb class WanTimeTextAudioPoseEmbedding(nn.Module): @@ -972,16 +965,11 @@ def inject_motion( mot, mot_remb = self.process_motion(motion_latents, drop_motion_frames) if len(mot) > 0: - hidden_states = [torch.cat([u.unsqueeze(0), m], dim=1) for u, m in zip(hidden_states, mot)] - seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long) - rope_embs = [torch.cat([u.unsqueeze(0), m], dim=1) for u, m in zip(rope_embs, mot_remb)] - mask_input = [ - torch.cat( - [m, 2 * torch.ones([1, u.shape[1] - m.shape[1]], device=m.device, dtype=m.dtype)], - dim=1, - ) - for m, u in zip(mask_input, hidden_states) - ] + hidden_states = torch.cat([hidden_states, mot], dim=1) + seq_lens = seq_lens + torch.tensor([mot.shape[1]], dtype=torch.long) + rope_embs = torch.cat([rope_embs, mot_remb], dim=1) + mask_input = torch.cat([mask_input, 2 * torch.ones([1, hidden_states.shape[1] - mask_input.shape[1]], + device=mask_input.device, dtype=mask_input.dtype)], dim=1) return hidden_states, seq_lens, rope_embs, mask_input def after_transformer_block( diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index be71a74f6992..96f9813b566d 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -590,13 +590,15 @@ def load_pose_condition( else: pose_video = [-torch.ones([1, 3, num_frames_per_chunk, height, width])] - pose_condition = [] - for cond in pose_video: - cond = torch.cat([cond[:, :, 0:1], cond], dim=2) - cond = cond.to(dtype=self.dtype, device=self._execution_device) - cond_lat = retrieve_latents(self.vae.encode(cond), sample_mode="argmax")[:, :, 1:] - cond_lat = (cond_lat - latents_mean) * latents_std - pose_condition.append(cond_lat) + # Vectorized processing: concatenate all chunks along batch dimension + all_poses = torch.cat([ + torch.cat([cond[:, :, 0:1], cond], dim=2) + for cond in pose_video + ], dim=0) # Shape: [num_chunks, 3, num_frames_per_chunk+1, height, width] + + all_poses = all_poses.to(dtype=self.vae.dtype, device=self.vae.device) + pose_condition = retrieve_latents(self.vae.encode(all_poses), sample_mode="argmax")[:, :, 1:] + pose_condition = (pose_condition - latents_mean) * latents_std return pose_condition From 4ab554769d5da80744a50586af8b7731ecdbf67f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 18:01:20 +0300 Subject: [PATCH 073/131] style --- .../transformers/transformer_wan_s2v.py | 33 ++++++++++++++----- .../pipelines/wan/pipeline_wan_s2v.py | 23 ++++++++----- 2 files changed, 39 insertions(+), 17 deletions(-) 
diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index cc1157f8fad8..a538191834b5 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -416,8 +416,11 @@ def __init__( def forward(self, motion_latents, add_last_motion=2): latent_height, latent_width = motion_latents.shape[3], motion_latents.shape[4] - padd_latent = torch.zeros(motion_latents.shape[0], 16, self.zip_frame_buckets.sum(), latent_height, latent_width).to( - device=motion_latents.device, dtype=motion_latents.dtype) + padd_latent = torch.zeros( + (motion_latents.shape[0], 16, self.zip_frame_buckets.sum(), latent_height, latent_width), + device=motion_latents.device, + dtype=motion_latents.dtype, + ) overlap_frame = min(padd_latent.shape[2], motion_latents.shape[2]) if overlap_frame > 0: padd_latent[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:] @@ -484,7 +487,10 @@ def forward(self, motion_latents, add_last_motion=2): motion_rope_emb = self.rope( motion_lat.detach().view( - motion_lat.shape[0], motion_lat.shape[1], self.num_attention_heads, self.inner_dim // self.num_attention_heads + motion_lat.shape[0], + motion_lat.shape[1], + self.num_attention_heads, + self.inner_dim // self.num_attention_heads, ), grid_sizes=grid_sizes, ) @@ -514,7 +520,6 @@ def __init__( self.causal_audio_encoder = CausalAudioEncoder( dim=audio_embed_dim, out_dim=dim, num_audio_token=4, need_global=enable_adain ) - self.pose_embedder = nn.Conv3d(pose_embed_dim, dim, kernel_size=patch_size, stride=patch_size) def forward( @@ -618,8 +623,10 @@ def forward( # Loop over samples output = torch.view_as_complex( torch.zeros( - (batch_size, S, self.num_attention_heads, self.attention_head_dim // 2, 2), device=hidden_states.device - ).to(torch.float64) + (batch_size, S, self.num_attention_heads, self.attention_head_dim // 2, 2), + device=hidden_states.device, + dtype=torch.float64, + ) ) seq_bucket = [0] for g in grids: @@ -968,8 +975,18 @@ def inject_motion( hidden_states = torch.cat([hidden_states, mot], dim=1) seq_lens = seq_lens + torch.tensor([mot.shape[1]], dtype=torch.long) rope_embs = torch.cat([rope_embs, mot_remb], dim=1) - mask_input = torch.cat([mask_input, 2 * torch.ones([1, hidden_states.shape[1] - mask_input.shape[1]], - device=mask_input.device, dtype=mask_input.dtype)], dim=1) + mask_input = torch.cat( + [ + mask_input, + 2 + * torch.ones( + [1, hidden_states.shape[1] - mask_input.shape[1]], + device=mask_input.device, + dtype=mask_input.dtype, + ), + ], + dim=1, + ) return hidden_states, seq_lens, rope_embs, mask_input def after_transformer_block( diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 96f9813b566d..043939f83acc 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -516,7 +516,6 @@ def prepare_latents( pose_video: Optional[List[Image.Image]] = None, init_first_frame: bool = False, num_chunks: int = 1, - sampling_fps: int = 16, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[torch.Tensor]]]: num_latent_frames = ( num_frames_per_chunk + 3 + self.motion_frames @@ -584,19 +583,26 @@ def load_pose_condition( ): if pose_video is not None: padding_frame_num = num_chunks * num_frames_per_chunk - pose_video.shape[2] - pose_video = torch.cat([pose_video, -torch.ones([1, 3, padding_frame_num, height, 
width])], dim=2) + pose_video = pose_video.to(dtype=self.vae.dtype, device=self.vae.device) + pose_video = torch.cat( + [ + pose_video, + -torch.ones( + [1, 3, padding_frame_num, height, width], dtype=self.vae.dtype, device=self.vae.device + ), + ], + dim=2, + ) pose_video = torch.chunk(pose_video, num_chunks, dim=2) else: pose_video = [-torch.ones([1, 3, num_frames_per_chunk, height, width])] # Vectorized processing: concatenate all chunks along batch dimension - all_poses = torch.cat([ - torch.cat([cond[:, :, 0:1], cond], dim=2) - for cond in pose_video - ], dim=0) # Shape: [num_chunks, 3, num_frames_per_chunk+1, height, width] + all_poses = torch.cat( + [torch.cat([cond[:, :, 0:1], cond], dim=2) for cond in pose_video], dim=0 + ) # Shape: [num_chunks, 3, num_frames_per_chunk+1, height, width] - all_poses = all_poses.to(dtype=self.vae.dtype, device=self.vae.device) pose_condition = retrieve_latents(self.vae.encode(all_poses), sample_mode="argmax")[:, :, 1:] pose_condition = (pose_condition - latents_mean) * latents_std @@ -820,7 +826,7 @@ def __call__( num_chunks = num_chunks_audio audio_embeds = audio_embeds.to(transformer_dtype) - latent_motion_frames = (self.motion_frames + 3) // 4 + latent_motion_frames = (self.motion_frames + 3) // self.vae_scale_factor_temporal # 5. Prepare latent variables num_channels_latents = self.vae.config.z_dim @@ -848,7 +854,6 @@ def __call__( pose_video, init_first_frame, num_chunks, - sampling_fps, ) if r == 0: From 4e5f3570ae9dacf19804a41ce7688ad5ceebb6ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 19:44:06 +0300 Subject: [PATCH 074/131] Fix pose_video tensor initialization to use correct dtype and device --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 043939f83acc..57f17885a7c4 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -596,7 +596,7 @@ def load_pose_condition( pose_video = torch.chunk(pose_video, num_chunks, dim=2) else: - pose_video = [-torch.ones([1, 3, num_frames_per_chunk, height, width])] + pose_video = [-torch.ones([1, 3, num_frames_per_chunk, height, width], dtype=self.vae.dtype, device=self.vae.device)] # Vectorized processing: concatenate all chunks along batch dimension all_poses = torch.cat( From b9224b92c92879483b49f8b9e7d381cd79838f4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 19:53:54 +0300 Subject: [PATCH 075/131] up --- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index a538191834b5..881f9d1c6dff 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -951,7 +951,7 @@ def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, ad flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion) if drop_motion_frames: - return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb] + return flattern_mot[:,:,:0], mot_remb[:,:,:0] else: return flattern_mot, mot_remb From bcf71dbe8809ea65dd9d92e192de950fb7b4e416 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 19:56:33 +0300 
Subject: [PATCH 076/131] =?UTF-8?q?=C4=B1p?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 881f9d1c6dff..980e2cea0d48 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -951,7 +951,7 @@ def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, ad flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion) if drop_motion_frames: - return flattern_mot[:,:,:0], mot_remb[:,:,:0] + return flattern_mot[:, :0], mot_remb[:, :0] else: return flattern_mot, mot_remb From c248b6dc829d69a878efa55f4c17c85f61ea8f48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 19:59:17 +0300 Subject: [PATCH 077/131] fix --- src/diffusers/models/transformers/transformer_wan_s2v.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 980e2cea0d48..50ea72b33872 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -978,9 +978,8 @@ def inject_motion( mask_input = torch.cat( [ mask_input, - 2 - * torch.ones( - [1, hidden_states.shape[1] - mask_input.shape[1]], + 2 * torch.ones( + [mask_input.shape[0], 1, hidden_states.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype, ), From 206bbaabd754386d3890427690a5ac497c9ce023 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 8 Sep 2025 20:04:01 +0300 Subject: [PATCH 078/131] up --- .../models/transformers/transformer_wan_s2v.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 50ea72b33872..cf8fbc364524 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -979,12 +979,12 @@ def inject_motion( [ mask_input, 2 * torch.ones( - [mask_input.shape[0], 1, hidden_states.shape[1] - mask_input.shape[1]], + [mask_input.shape[0], 1, hidden_states.shape[1] - mask_input.shape[2]], device=mask_input.device, dtype=mask_input.dtype, ), ], - dim=1, + dim=2, ) return hidden_states, seq_lens, rope_embs, mask_input @@ -1128,9 +1128,9 @@ def forward( add_last_motion, ) - hidden_states = torch.cat(hidden_states) - rotary_emb = torch.cat(rotary_emb) - mask_input = torch.cat(mask_input) + #hidden_states = torch.cat(hidden_states) + #rotary_emb = torch.cat(rotary_emb) + #mask_input = torch.cat(mask_input) hidden_states = hidden_states + self.trainable_condition_mask(mask_input).to(hidden_states.dtype) From a126570b8979cd1b51e790db62c1ce3da0dce075 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 9 Sep 2025 09:19:39 +0300 Subject: [PATCH 079/131] Fix mask_input tensor shape and dimension in WanS2VTransformer3DModel --- .../models/transformers/transformer_wan_s2v.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py 
b/src/diffusers/models/transformers/transformer_wan_s2v.py index cf8fbc364524..d7cdc631bdd5 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -979,12 +979,12 @@ def inject_motion( [ mask_input, 2 * torch.ones( - [mask_input.shape[0], 1, hidden_states.shape[1] - mask_input.shape[2]], + [1, hidden_states.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype, ), ], - dim=2, + dim=1, ) return hidden_states, seq_lens, rope_embs, mask_input @@ -1111,11 +1111,7 @@ def forward( # Initialize masks to indicate noisy latent, image latent, and motion latent. # However, at this point, only the first two (noisy and image latents) are marked; # the marking of motion latent will be implemented inside `inject_motion`. - mask_input = ( - torch.zeros([1, hidden_states.shape[1]], dtype=torch.long, device=hidden_states.device) - .unsqueeze(0) - .repeat(batch_size, 1, 1) - ) + mask_input = torch.zeros([1, hidden_states.shape[1]], dtype=torch.long, device=hidden_states.device) mask_input[:, :, original_sequence_length:] = 1 hidden_states, sequence_length, rotary_emb, mask_input = self.inject_motion( From bd0b72ed7f31db3b9c18b6d71faf4b415a56e3cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 9 Sep 2025 09:29:13 +0300 Subject: [PATCH 080/131] Fix mask_input tensor indexing in WanS2VTransformer3DModel --- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index d7cdc631bdd5..605f83711cfc 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -1112,7 +1112,7 @@ def forward( # However, at this point, only the first two (noisy and image latents) are marked; # the marking of motion latent will be implemented inside `inject_motion`. mask_input = torch.zeros([1, hidden_states.shape[1]], dtype=torch.long, device=hidden_states.device) - mask_input[:, :, original_sequence_length:] = 1 + mask_input[:, original_sequence_length:] = 1 hidden_states, sequence_length, rotary_emb, mask_input = self.inject_motion( hidden_states, From 29bddb5f0ce7aa470e64c9eaf618fafa5f0b68a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 10 Sep 2025 09:44:23 +0300 Subject: [PATCH 081/131] up docs --- docs/source/en/api/pipelines/wan.md | 23 ++-- .../transformers/transformer_wan_s2v.py | 7 +- .../pipelines/wan/pipeline_wan_s2v.py | 125 +++++++++++++----- 3 files changed, 111 insertions(+), 44 deletions(-) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 495be3ff880d..0f63e099bc67 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -244,7 +244,7 @@ export_to_video(output, "output.mp4", fps=16) *Current state-of-the-art (SOTA) methods for audio-driven character animation demonstrate promising performance for scenarios primarily involving speech and singing. However, they often fall short in more complex film and television productions, which demand sophisticated elements such as nuanced character interactions, realistic body movements, and dynamic camera work. To address this long-standing challenge of achieving film-level character animation, we propose an audio-driven model, which we refere to as Wan-S2V, built upon Wan. 
Our model achieves significantly enhanced expressiveness and fidelity in cinematic contexts compared to existing approaches. We conducted extensive experiments, benchmarking our method against cutting-edge models such as Hunyuan-Avatar and Omnihuman. The experimental results consistently demonstrate that our approach significantly outperforms these existing solutions. Additionally, we explore the versatility of our method through its applications in long-form video generation and precise video lip-sync editing.* -The example below demonstrates how to use the speech-to-video pipeline to generate a video using a text description, a starting frame, and an audio. +The example below demonstrates how to use the speech-to-video pipeline to generate a video using a text description, a starting frame, an audio, and a pose video. @@ -255,6 +255,9 @@ import torch from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline from diffusers.utils import export_to_video, load_image, load_audio, load_video from transformers import Wav2Vec2ForCTC +import requests +from PIL import Image +from io import BytesIO model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" @@ -265,9 +268,13 @@ pipe = WanSpeechToVideoPipeline.from_pretrained( ) pipe.to("cuda") -first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png") -audio, sampling_rate = load_audio("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png") -#pose_path = E.g., download from "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4" +headers = {"User-Agent": "Mozilla/5.0"} +url = "https://upload.wikimedia.org/wikipedia/commons/4/46/Albert_Einstein_sticks_his_tongue.jpg" +resp = requests.get(url, headers=headers, timeout=30) +image = Image.open(BytesIO(resp.content)) + +audio, sampling_rate = load_audio("https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/Five%20Hundred%20Miles.MP3") +#pose_video_path_or_url = "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4" def get_size_less_than_area(height, width, @@ -334,14 +341,14 @@ def aspect_ratio_resize(image, pipe, max_area): image = image.resize((width, height)) return image, height, width -first_frame, height, width = aspect_ratio_resize(first_frame, pipe, 480*832) +image, height, width = aspect_ratio_resize(image, pipe, 480*832) -prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." +prompt = "Einstein singing a song."
output = pipe( - image=first_frame, audio=audio, sampling_rate=sampling_rate, - prompt=prompt, height=height, width=width, num_frames_per_chunk=81, - #pose_video=pose_path + prompt=prompt, image=image, audio=audio, sampling_rate=sampling_rate, + height=height, width=width, num_frames_per_chunk=81, + #pose_video_path_or_url=pose_video_path_or_url, ).frames[0] export_to_video(output, "output.mp4", fps=16) ``` diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 605f83711cfc..886a2e8edf28 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -978,7 +978,8 @@ def inject_motion( mask_input = torch.cat( [ mask_input, - 2 * torch.ones( + 2 + * torch.ones( [1, hidden_states.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype, @@ -1124,10 +1125,6 @@ def forward( add_last_motion, ) - #hidden_states = torch.cat(hidden_states) - #rotary_emb = torch.cat(rotary_emb) - #mask_input = torch.cat(mask_input) - hidden_states = hidden_states + self.trainable_condition_mask(mask_input).to(hidden_states.dtype) if self.config.zero_timestep: diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 57f17885a7c4..1dca8b8e4094 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -29,7 +29,7 @@ from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, WanS2VTransformer3DModel from ...schedulers import UniPCMultistepScheduler -from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils import is_ftfy_available, is_torch_xla_available, load_video, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline @@ -51,53 +51,108 @@ EXAMPLE_DOC_STRING = """ Examples: ```python + >>> import numpy as np, math, requests >>> import torch - >>> import numpy as np >>> from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline - >>> from diffusers.utils import export_to_video, load_image, load_audio, load_video + >>> from diffusers.utils import export_to_video, load_audio >>> from transformers import Wav2Vec2ForCTC + >>> from PIL import Image + >>> from io import BytesIO - >>> # Available models: Wan-AI/Wan2.2-S2V-14B-Diffusers >>> model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" - >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) >>> audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32) + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) >>> pipe = WanSpeechToVideoPipeline.from_pretrained( ... model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16 ... ) >>> pipe.to("cuda") - >>> first_frame = load_image( - ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png" - ... ) - >>> audio, sampling_rate = load_audio( - ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png" - ... ) - >>> pose_video = load_video( - ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_pose_video.mp4" - ... 
) + >>> headers = {"User-Agent": "Mozilla/5.0"} + >>> url = "https://upload.wikimedia.org/wikipedia/commons/4/46/Albert_Einstein_sticks_his_tongue.jpg" + >>> resp = requests.get(url, headers=headers, timeout=30) + >>> image = Image.open(BytesIO(resp.content)) - >>> max_area = 480 * 832 - >>> aspect_ratio = image.height / image.width - >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] - >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value - >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value - >>> image = image.resize((width, height)) - >>> prompt = ( - ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in " - ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + >>> audio, sampling_rate = load_audio( + ... "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/Five%20Hundred%20Miles.MP3" ... ) - >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + >>> # pose_video_path_or_url = "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4" + + + >>> def get_size_less_than_area(height, width, target_area=1024 * 704, divisor=64): + ... if height * width <= target_area: + ... # If the original image area is already less than or equal to the target, + ... # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target. + ... max_upper_area = target_area + ... min_scale = 0.1 + ... max_scale = 1.0 + ... else: + ... # Resize to fit within the target area and then pad to multiples of `divisor` + ... max_upper_area = target_area # Maximum allowed total pixel count after padding + ... d = divisor - 1 + ... b = d * (height + width) + ... a = height * width + ... c = d**2 - max_upper_area + + ... # Calculate scale boundaries using quadratic equation + ... min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (2 * a) # Scale when maximum padding is applied + ... max_scale = math.sqrt(max_upper_area / (height * width)) # Scale without any padding + + ... # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area + ... # Use binary search-like iteration to find this scale + ... find_it = False + ... for i in range(100): + ... scale = max_scale - (max_scale - min_scale) * i / 100 + ... new_height, new_width = int(height * scale), int(width * scale) + + ... # Pad to make dimensions divisible by 64 + ... pad_height = (64 - new_height % 64) % 64 + ... pad_width = (64 - new_width % 64) % 64 + ... pad_top = pad_height // 2 + ... pad_bottom = pad_height - pad_top + ... pad_left = pad_width // 2 + ... pad_right = pad_width - pad_left + + ... padded_height, padded_width = new_height + pad_height, new_width + pad_width + + ... if padded_height * padded_width <= max_upper_area: + ... find_it = True + ... break + + ... if find_it: + ... return padded_height, padded_width + ... else: + ... # Fallback: calculate target dimensions based on aspect ratio and divisor alignment + ... aspect_ratio = width / height + ... 
target_width = int((target_area * aspect_ratio) ** 0.5 // divisor * divisor) + ... target_height = int((target_area / aspect_ratio) ** 0.5 // divisor * divisor) + + ... # Ensure the result is not larger than the original resolution + ... if target_width >= width or target_height >= height: + ... target_width = int(width // divisor * divisor) + ... target_height = int(height // divisor * divisor) + + ... return target_height, target_width + + + >>> def aspect_ratio_resize(image, pipe, max_area): + ... height, width = get_size_less_than_area(image.size[1], image.size[0], target_area=max_area) + ... image = image.resize((width, height)) + ... return image, height, width + + + >>> image, height, width = aspect_ratio_resize(first_frame, pipe, 480 * 832) + + >>> prompt = "Einstein singing a song." >>> output = pipe( ... prompt=prompt, ... image=image, ... audio=audio, ... sampling_rate=sampling_rate, - ... # pose_video=pose_video, - ... negative_prompt=negative_prompt, ... height=height, ... width=width, ... num_frames_per_chunk=81, + ... # pose_video_path_or_url=pose_video_path_or_url, ... ).frames[0] >>> export_to_video(output, "output.mp4", fps=16) ``` @@ -596,7 +651,9 @@ def load_pose_condition( pose_video = torch.chunk(pose_video, num_chunks, dim=2) else: - pose_video = [-torch.ones([1, 3, num_frames_per_chunk, height, width], dtype=self.vae.dtype, device=self.vae.device)] + pose_video = [ + -torch.ones([1, 3, num_frames_per_chunk, height, width], dtype=self.vae.dtype, device=self.vae.device) + ] # Vectorized processing: concatenate all chunks along batch dimension all_poses = torch.cat( @@ -641,7 +698,7 @@ def __call__( sampling_rate: int, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] = None, - pose_video: Optional[List[Image.Image]] = None, + pose_video_path_or_url: Optional[str] = None, height: int = 480, width: int = 832, num_frames_per_chunk: int = 81, @@ -683,8 +740,8 @@ def __call__( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - pose_video (`List[Image.Image]`, *optional*): - A list of PIL images representing the pose video to condition the generation on. + pose_video_path_or_url (`str` or `List[str]`, *optional*): + The path or URL to the pose video to condition the generation on. height (`int`, defaults to `480`): The height of the generated video. 
width (`int`, defaults to `832`): @@ -832,7 +889,13 @@ def __call__( num_channels_latents = self.vae.config.z_dim image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) - if pose_video is not None: + if pose_video_path_or_url is not None: + pose_video = load_video( + pose_video_path_or_url, + n_frames=num_frames_per_chunk * num_chunks, + target_fps=sampling_fps, + reverse=True, + ) pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width).to( device, dtype=torch.float32 ) From 746514f592cd699d1c4c2f574c278c154c6b3475 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 10 Sep 2025 12:23:04 +0300 Subject: [PATCH 082/131] Adds video and audio merging functionality in docs --- docs/source/en/api/pipelines/wan.md | 60 ++++++++++++++++++ .../pipelines/wan/pipeline_wan_s2v.py | 61 +++++++++++++++++++ 2 files changed, 121 insertions(+) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 0f63e099bc67..c206a6cb02c1 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -351,6 +351,66 @@ output = pipe( #pose_video_path_or_url=pose_video_path_or_url, ).frames[0] export_to_video(output, "output.mp4", fps=16) + +# Lastly, we need to merge the video and audio into a new video, with the duration set to +# the shorter of the two and overwrite the original video file. + +import os, logging, subprocess, shutil + +def merge_video_audio(video_path: str, audio_path: str): + logging.basicConfig(level=logging.INFO) + + if not os.path.exists(video_path): + raise FileNotFoundError(f"video file {video_path} does not exist") + if not os.path.exists(audio_path): + raise FileNotFoundError(f"audio file {audio_path} does not exist") + + base, ext = os.path.splitext(video_path) + temp_output = f"{base}_temp{ext}" + + try: + # Create ffmpeg command + command = [ + 'ffmpeg', + '-y', # overwrite + '-i', + video_path, + '-i', + audio_path, + '-c:v', + 'copy', # copy video stream + '-c:a', + 'aac', # use AAC audio encoder + '-b:a', + '192k', # set audio bitrate (optional) + '-map', + '0:v:0', # select the first video stream + '-map', + '1:a:0', # select the first audio stream + '-shortest', # choose the shortest duration + temp_output + ] + + # Execute the command + logging.info("Start merging video and audio...") + result = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + # Check result + if result.returncode != 0: + error_msg = f"FFmpeg execute failed: {result.stderr}" + logging.error(error_msg) + raise RuntimeError(error_msg) + + shutil.move(temp_output, video_path) + logging.info(f"Merge completed, saved to {video_path}") + + except Exception as e: + if os.path.exists(temp_output): + os.remove(temp_output) + logging.error(f"merge_video_audio failed with error: {e}") + +merge_video_audio("output.mp4", "audio.mp3") ``` diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 1dca8b8e4094..ad66d9b8fc03 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -155,6 +155,67 @@ ... # pose_video_path_or_url=pose_video_path_or_url, ... ).frames[0] >>> export_to_video(output, "output.mp4", fps=16) + + >>> # Lastly, we need to merge the video and audio into a new video, with the duration set to + >>> # the shorter of the two and overwrite the original video file. 
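# Illustrative pre-check (a minimal sketch, not part of the documented example): the merge step
# below shells out to the ffmpeg binary, so it is worth verifying that it is installed first.
import shutil

if shutil.which("ffmpeg") is None:
    raise RuntimeError("ffmpeg was not found on PATH; install it before running the merge step below")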
+ + >>> import os, logging, subprocess, shutil + + + >>> def merge_video_audio(video_path: str, audio_path: str): + ... logging.basicConfig(level=logging.INFO) + + ... if not os.path.exists(video_path): + ... raise FileNotFoundError(f"video file {video_path} does not exist") + ... if not os.path.exists(audio_path): + ... raise FileNotFoundError(f"audio file {audio_path} does not exist") + + ... base, ext = os.path.splitext(video_path) + ... temp_output = f"{base}_temp{ext}" + + ... try: + ... # Create ffmpeg command + ... command = [ + ... "ffmpeg", + ... "-y", # overwrite + ... "-i", + ... video_path, + ... "-i", + ... audio_path, + ... "-c:v", + ... "copy", # copy video stream + ... "-c:a", + ... "aac", # use AAC audio encoder + ... "-b:a", + ... "192k", # set audio bitrate (optional) + ... "-map", + ... "0:v:0", # select the first video stream + ... "-map", + ... "1:a:0", # select the first audio stream + ... "-shortest", # choose the shortest duration + ... temp_output, + ... ] + + ... # Execute the command + ... logging.info("Start merging video and audio...") + ... result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + ... # Check result + ... if result.returncode != 0: + ... error_msg = f"FFmpeg execute failed: {result.stderr}" + ... logging.error(error_msg) + ... raise RuntimeError(error_msg) + + ... shutil.move(temp_output, video_path) + ... logging.info(f"Merge completed, saved to {video_path}") + + ... except Exception as e: + ... if os.path.exists(temp_output): + ... os.remove(temp_output) + ... logging.error(f"merge_video_audio failed with error: {e}") + + + >>> merge_video_audio("output.mp4", "audio.mp3") ``` """ From b2e57b8e0c964858f808c78654ce0a6a56e1d077 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 10 Sep 2025 16:59:12 +0300 Subject: [PATCH 083/131] fix: initialize pose_video variable in WanSpeechToVideoPipeline --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index ad66d9b8fc03..b6a937b46179 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -950,6 +950,7 @@ def __call__( num_channels_latents = self.vae.config.z_dim image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + pose_video = None if pose_video_path_or_url is not None: pose_video = load_video( pose_video_path_or_url, From f7fbf3635206253120170bc44a5537df28d294ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 12 Sep 2025 18:32:18 +0300 Subject: [PATCH 084/131] Enables passing attention kwargs --- .../models/transformers/transformer_wan_s2v.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 886a2e8edf28..a4156a041eaf 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -101,7 +101,9 @@ def __call__( encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, ) -> torch.Tensor: + attention_kwargs = kwargs encoder_hidden_states_img = None if attn.add_k_proj is not None: # 512 is the context length of the text 
encoder, hardcoded for now @@ -156,6 +158,7 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): dropout_p=0.0, is_causal=False, backend=self._attention_backend, + attention_kwargs=attention_kwargs, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) @@ -719,6 +722,7 @@ def forward( encoder_hidden_states: torch.Tensor, temb: Tuple[torch.Tensor, torch.Tensor], rotary_emb: torch.Tensor, + attention_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: seg_idx = temb[1].item() seg_idx = min(max(0, seg_idx), hidden_states.shape[1]) @@ -746,7 +750,7 @@ def forward( norm_hidden_states = torch.cat(parts, dim=1).type_as(hidden_states) # 1. Self-attention - attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb, **(attention_kwargs or {})) z = [] for i in range(2): z.append(attn_output[:, seg_idx[i] : seg_idx[i + 1]] * gate_msa[:, i : i + 1]) @@ -755,7 +759,7 @@ def forward( # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) - attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None, **(attention_kwargs or {})) hidden_states = hidden_states + attn_output # 3. Feed-forward @@ -1124,6 +1128,7 @@ def forward( drop_motion_frames, add_last_motion, ) + attention_kwargs = {"max_seqlen_k": sequence_length.item()} hidden_states = hidden_states + self.trainable_condition_mask(mask_input).to(hidden_states.dtype) @@ -1148,7 +1153,7 @@ def forward( if torch.is_grad_enabled() and self.gradient_checkpointing: for block_idx, block in enumerate(self.blocks): hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_kwargs=attention_kwargs ) hidden_states = self.after_transformer_block( block_idx, @@ -1160,7 +1165,7 @@ def forward( ) else: for block_idx, block in enumerate(self.blocks): - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_kwargs=attention_kwargs) hidden_states = self.after_transformer_block( block_idx, hidden_states, From 111085f4eb36e84097da5526d708542910004cdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 12 Sep 2025 18:33:56 +0300 Subject: [PATCH 085/131] Propose flash attention with precomputed max_seqlen_k-only --- src/diffusers/models/attention_dispatch.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index f71be7c8ecc0..718f92da85fa 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -602,13 +602,23 @@ def _flash_varlen_attention( if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + if max_seqlen_k is not None: + seqlens_k = torch.full((batch_size,), 
max_seqlen_k, dtype=torch.int32, device=query.device) + + (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, _) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, max_seqlen_k, attn_mask=attn_mask, device=query.device + ) + ) + max_seqlen_k = seq_len_kv + else: + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) ) - ) else: seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) From 8b58f6359c29960c5e3463738aca500601667bd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 15 Sep 2025 13:39:35 +0300 Subject: [PATCH 086/131] style --- src/diffusers/models/attention_dispatch.py | 10 ++--- .../transformers/transformer_wan_s2v.py | 44 ++++++++++++------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 718f92da85fa..c03ec8f7ad9a 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -602,15 +602,13 @@ def _flash_varlen_attention( if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): if max_seqlen_k is not None: seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) - - (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, _) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, max_seqlen_k, attn_mask=attn_mask, device=query.device - ) + + (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, _) = _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, max_seqlen_k, attn_mask=attn_mask, device=query.device ) max_seqlen_k = seq_len_kv else: diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index a4156a041eaf..67526d928463 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -85,13 +85,13 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t return key_img, value_img -class WanAttnProcessor: +class WanS2VAttnProcessor: _attention_backend = None def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( - "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + "WanS2VAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." 
) def __call__( @@ -103,7 +103,6 @@ def __call__( rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> torch.Tensor: - attention_kwargs = kwargs encoder_hidden_states_img = None if attn.add_k_proj is not None: # 512 is the context length of the text encoder, hardcoded for now @@ -158,7 +157,7 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): dropout_p=0.0, is_causal=False, backend=self._attention_backend, - attention_kwargs=attention_kwargs, + attention_kwargs=kwargs, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) @@ -171,6 +170,7 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): dropout_p=0.0, is_causal=False, backend=self._attention_backend, + attention_kwargs=kwargs, ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) @@ -358,7 +358,7 @@ def __init__( eps=eps, added_kv_proj_dim=added_kv_proj_dim, cross_attention_dim_head=dim // num_heads, - processor=WanAttnProcessor(), + processor=WanS2VAttnProcessor(), ) for _ in range(audio_injector_id) ] @@ -695,7 +695,7 @@ def __init__( dim_head=dim // num_heads, eps=eps, cross_attention_dim_head=None, - processor=WanAttnProcessor(), + processor=WanS2VAttnProcessor(), ) # 2. Cross-attention @@ -706,7 +706,7 @@ def __init__( eps=eps, added_kv_proj_dim=added_kv_proj_dim, cross_attention_dim_head=dim // num_heads, - processor=WanAttnProcessor(), + processor=WanS2VAttnProcessor(), ) self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() @@ -839,11 +839,11 @@ def __init__( out_channels: int = 16, text_dim: int = 4096, freq_dim: int = 256, - audio_dim: int = 1280, - audio_inject_layers: List[int] = [0, 4, 8, 12, 16, 20, 24, 27], + audio_dim: int = 1024, + audio_inject_layers: List[int] = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39], enable_adain: bool = True, adain_mode: str = "attn_norm", - pose_dim: int = 1280, + pose_dim: int = 16, ffn_dim: int = 13824, num_layers: int = 40, cross_attn_norm: bool = True, @@ -851,9 +851,9 @@ def __init__( eps: float = 1e-6, added_kv_proj_dim: Optional[int] = None, rope_max_seq_len: int = 1024, - enable_framepack: bool = False, + enable_framepack: bool = True, framepack_drop_mode: str = "padd", - add_last_motion: bool = False, + add_last_motion: bool = True, zero_timestep: bool = True, ) -> None: super().__init__() @@ -1015,9 +1015,14 @@ def after_transformer_block( else: attn_hidden_states = self.audio_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states) + attention_kwargs = { + "max_seqlen_k": torch.ones( + attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device + ) + * attn_audio_emb.shape[1] + } residual_out = self.audio_injector.injector[audio_attn_id]( - attn_hidden_states, - attn_audio_emb, + attn_hidden_states, attn_audio_emb, None, None, attention_kwargs ) residual_out = residual_out.unflatten(0, (-1, merged_audio_emb_num_frames)).flatten(1, 2) hidden_states[:, :original_sequence_length] = hidden_states[:, :original_sequence_length] + residual_out @@ -1153,7 +1158,12 @@ def forward( if torch.is_grad_enabled() and self.gradient_checkpointing: for block_idx, block in enumerate(self.blocks): hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_kwargs=attention_kwargs + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + attention_kwargs, ) 
hidden_states = self.after_transformer_block( block_idx, @@ -1165,7 +1175,9 @@ def forward( ) else: for block_idx, block in enumerate(self.blocks): - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_kwargs=attention_kwargs) + hidden_states = block( + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_kwargs + ) hidden_states = self.after_transformer_block( block_idx, hidden_states, From 66e58b8fb078447e7f48bbd9c7d3541d20ce67ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 15 Sep 2025 13:45:08 +0300 Subject: [PATCH 087/131] Propose to add `FP32RMSNorm` --- .../models/transformers/transformer_wan.py | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 968a0369c243..a5322b28324d 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -166,6 +166,24 @@ def __new__(cls, *args, **kwargs): return WanAttnProcessor(*args, **kwargs) +class FP32RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + class WanAttention(torch.nn.Module, AttentionModuleMixin): _default_processor_cls = WanAttnProcessor _available_processors = [WanAttnProcessor] @@ -199,14 +217,14 @@ def __init__( torch.nn.Dropout(dropout), ] ) - self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) - self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_q = FP32RMSNorm(dim_head * heads, eps=eps) + self.norm_k = FP32RMSNorm(dim_head * heads, eps=eps) self.add_k_proj = self.add_v_proj = None if added_kv_proj_dim is not None: self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) - self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) + self.norm_added_k = FP32RMSNorm(dim_head * heads, eps=eps) self.is_cross_attention = cross_attention_dim_head is not None From f503a26abcc19314d00786f50b3ab6cbdecbd497 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 16 Sep 2025 08:31:45 +0300 Subject: [PATCH 088/131] Fix argument unpacking in audio injector call in WanS2VTransformer3DModel --- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 67526d928463..6731efa8931e 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -1022,7 +1022,7 @@ def after_transformer_block( * attn_audio_emb.shape[1] } residual_out = self.audio_injector.injector[audio_attn_id]( - attn_hidden_states, attn_audio_emb, None, None, attention_kwargs + attn_hidden_states, attn_audio_emb, None, None, **attention_kwargs ) residual_out = residual_out.unflatten(0, (-1, merged_audio_emb_num_frames)).flatten(1, 2) hidden_states[:, :original_sequence_length] = 
hidden_states[:, :original_sequence_length] + residual_out From 9fe3596a175ac041bf9daaa50903894deab33248 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 17 Sep 2025 09:16:17 +0300 Subject: [PATCH 089/131] Remove `FP32RMSNorm` --- .../models/transformers/transformer_wan.py | 24 +++---------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index a5322b28324d..e2c0d3e508de 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -166,24 +166,6 @@ def __new__(cls, *args, **kwargs): return WanAttnProcessor(*args, **kwargs) -class FP32RMSNorm(nn.Module): - def __init__(self, dim, eps=1e-5): - super().__init__() - self.dim = dim - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def forward(self, x): - r""" - Args: - x(Tensor): Shape [B, L, C] - """ - return self._norm(x.float()).type_as(x) * self.weight - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) - - class WanAttention(torch.nn.Module, AttentionModuleMixin): _default_processor_cls = WanAttnProcessor _available_processors = [WanAttnProcessor] @@ -217,14 +199,14 @@ def __init__( torch.nn.Dropout(dropout), ] ) - self.norm_q = FP32RMSNorm(dim_head * heads, eps=eps) - self.norm_k = FP32RMSNorm(dim_head * heads, eps=eps) + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) self.add_k_proj = self.add_v_proj = None if added_kv_proj_dim is not None: self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) - self.norm_added_k = FP32RMSNorm(dim_head * heads, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) self.is_cross_attention = cross_attention_dim_head is not None From 3542a46ce9440124347f9d7c457afeed92648d1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Wed, 17 Sep 2025 09:17:49 +0300 Subject: [PATCH 090/131] Update src/diffusers/models/transformers/transformer_wan_s2v.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 6731efa8931e..0abced7a24da 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -220,7 +220,7 @@ def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch return x -class CausalConv1d(nn.Module): +class WanS2VCausalConv1d(nn.Module): def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): super().__init__() From 6761385730b7af09c6871ecffb0fdb385b750356 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Wed, 17 Sep 2025 09:20:07 +0300 Subject: [PATCH 091/131] Update src/diffusers/models/transformers/transformer_wan_s2v.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 0abced7a24da..8db9425a6525 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -360,7 +360,7 @@ def __init__( cross_attention_dim_head=dim // num_heads, processor=WanS2VAttnProcessor(), ) - for _ in range(audio_injector_id) + for _ in range(num_injection_layers) ] ) From e0b8ce93a8087bf18eb269275f1565c103d3d497 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Wed, 17 Sep 2025 09:24:37 +0300 Subject: [PATCH 092/131] Update src/diffusers/models/transformers/transformer_wan_s2v.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 8db9425a6525..42762aae9d37 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -393,8 +393,6 @@ def __init__( ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion patch_size=(1, 2, 2), - *args, - **kwargs, ): super().__init__(*args, **kwargs) self.inner_dim = inner_dim From b8b6709e561d043af1de5c500e5d83507bf8ae86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Wed, 17 Sep 2025 10:03:00 +0300 Subject: [PATCH 093/131] Update src/diffusers/models/transformers/transformer_wan_s2v.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 42762aae9d37..444056ce3276 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -234,7 +234,7 @@ def forward(self, x): return self.conv(x) -class MotionEncoder_tc(nn.Module): +class WanS2VMotionEncoder(nn.Module): def __init__(self, in_dim: int, hidden_dim: int, num_attention_heads: int, need_global: bool = True): super().__init__() From d2840fc0a1109c5f4f38ee50225cb9e8052f7600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 17 Sep 2025 10:21:34 +0300 Subject: [PATCH 094/131] Update module names --- .../models/transformers/transformer_wan_s2v.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 444056ce3276..e17f935bd132 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -240,12 +240,12 @@ def __init__(self, in_dim: int, hidden_dim: int, num_attention_heads: int, need_ self.num_attention_heads = num_attention_heads self.need_global = need_global - self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_attention_heads, 3, stride=1) + self.conv1_local = WanS2VCausalConv1d(in_dim, hidden_dim // 4 * num_attention_heads, 3, stride=1) if need_global: - self.conv1_global = CausalConv1d(in_dim, hidden_dim 
// 4, 3, stride=1) + self.conv1_global = WanS2VCausalConv1d(in_dim, hidden_dim // 4, 3, stride=1) self.act = nn.SiLU() - self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2) - self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2) + self.conv2 = WanS2VCausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2) + self.conv3 = WanS2VCausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2) if need_global: self.final_linear = nn.Linear(hidden_dim, hidden_dim) @@ -305,7 +305,7 @@ def forward(self, x): class CausalAudioEncoder(nn.Module): def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_audio_token=4, need_global=False): super().__init__() - self.encoder = MotionEncoder_tc( + self.encoder = WanS2VMotionEncoder( in_dim=dim, hidden_dim=out_dim, num_attention_heads=num_audio_token, need_global=need_global ) weight = torch.ones((1, num_layers, 1, 1)) * 0.01 @@ -394,7 +394,7 @@ def __init__( drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion patch_size=(1, 2, 2), ): - super().__init__(*args, **kwargs) + super().__init__() self.inner_dim = inner_dim self.num_attention_heads = num_attention_heads if (inner_dim % num_attention_heads) != 0 or (inner_dim // num_attention_heads) % 2 != 0: From a6c1b272d4b1aca927de5d53fd19d39982ad0455 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 17 Sep 2025 10:22:34 +0300 Subject: [PATCH 095/131] Adds `export_to_merged_video_audio` utility Adds a utility function to merge video and audio files using ffmpeg. This simplifies the process of combining audio and video outputs, especially useful in pipelines like WanSpeechToVideoPipeline. The function handles temporary file creation, command execution, and error handling for a more robust merging process. --- .../pipelines/wan/pipeline_wan_s2v.py | 60 +----------------- src/diffusers/utils/__init__.py | 2 +- src/diffusers/utils/export_utils.py | 62 +++++++++++++++++++ 3 files changed, 65 insertions(+), 59 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index b6a937b46179..512c94a552d4 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -54,7 +54,7 @@ >>> import numpy as np, math, requests >>> import torch >>> from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline - >>> from diffusers.utils import export_to_video, load_audio + >>> from diffusers.utils import export_to_video, load_audio, export_to_merged_video_audio >>> from transformers import Wav2Vec2ForCTC >>> from PIL import Image >>> from io import BytesIO @@ -159,63 +159,7 @@ >>> # Lastly, we need to merge the video and audio into a new video, with the duration set to >>> # the shorter of the two and overwrite the original video file. - >>> import os, logging, subprocess, shutil - - - >>> def merge_video_audio(video_path: str, audio_path: str): - ... logging.basicConfig(level=logging.INFO) - - ... if not os.path.exists(video_path): - ... raise FileNotFoundError(f"video file {video_path} does not exist") - ... if not os.path.exists(audio_path): - ... raise FileNotFoundError(f"audio file {audio_path} does not exist") - - ... base, ext = os.path.splitext(video_path) - ... temp_output = f"{base}_temp{ext}" - - ... try: - ... # Create ffmpeg command - ... command = [ - ... "ffmpeg", - ... "-y", # overwrite - ... "-i", - ... video_path, - ... "-i", - ... audio_path, - ... "-c:v", - ... "copy", # copy video stream - ... 
"-c:a", - ... "aac", # use AAC audio encoder - ... "-b:a", - ... "192k", # set audio bitrate (optional) - ... "-map", - ... "0:v:0", # select the first video stream - ... "-map", - ... "1:a:0", # select the first audio stream - ... "-shortest", # choose the shortest duration - ... temp_output, - ... ] - - ... # Execute the command - ... logging.info("Start merging video and audio...") - ... result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - - ... # Check result - ... if result.returncode != 0: - ... error_msg = f"FFmpeg execute failed: {result.stderr}" - ... logging.error(error_msg) - ... raise RuntimeError(error_msg) - - ... shutil.move(temp_output, video_path) - ... logging.info(f"Merge completed, saved to {video_path}") - - ... except Exception as e: - ... if os.path.exists(temp_output): - ... os.remove(temp_output) - ... logging.error(f"merge_video_audio failed with error: {e}") - - - >>> merge_video_audio("output.mp4", "audio.mp3") + >>> export_to_merged_video_audio("output.mp4", "audio.mp3") ``` """ diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 6c4290bf1d29..bc027d246227 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -41,7 +41,7 @@ from .deprecation_utils import deprecate from .doc_utils import replace_example_docstring from .dynamic_modules_utils import get_class_from_dynamic_module -from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video +from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video, export_to_merged_video_audio from .hub_utils import ( PushToHubMixin, _add_variant, diff --git a/src/diffusers/utils/export_utils.py b/src/diffusers/utils/export_utils.py index 07cf46928a44..6fc79418121d 100644 --- a/src/diffusers/utils/export_utils.py +++ b/src/diffusers/utils/export_utils.py @@ -1,6 +1,9 @@ import io +import os import random +import shutil import struct +import subprocess import tempfile from contextlib import contextmanager from typing import List, Optional, Union @@ -207,3 +210,62 @@ def export_to_video( writer.append_data(frame) return output_video_path + + +def export_to_merged_video_audio(video_path: str, audio_path: str): + """ + Merge the video and audio into a new video, with the duration set to the shorter of the two, and overwrite the + original video file. 
+ + Parameters: + video_path (str): Path to the original video file + audio_path (str): Path to the audio file + """ + if not os.path.exists(video_path): + raise FileNotFoundError(f"video file {video_path} does not exist") + if not os.path.exists(audio_path): + raise FileNotFoundError(f"audio file {audio_path} does not exist") + + base, ext = os.path.splitext(video_path) + temp_output = f"{base}_temp{ext}" + + try: + # Create ffmpeg command + command = [ + "ffmpeg", + "-y", # overwrite + "-i", + video_path, + "-i", + audio_path, + "-c:v", + "copy", # copy video stream + "-c:a", + "aac", # use AAC audio encoder + "-b:a", + "192k", # set audio bitrate (optional) + "-map", + "0:v:0", # select the first video stream + "-map", + "1:a:0", # select the first audio stream + "-shortest", # choose the shortest duration + temp_output, + ] + + # Execute the command + logger.info("Start merging video and audio...") + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + # Check result + if result.returncode != 0: + error_msg = f"FFmpeg execute failed: {result.stderr}" + logger.error(error_msg) + raise RuntimeError(error_msg) + + shutil.move(temp_output, video_path) + logger.info(f"Merge completed, saved to {video_path}") + + except Exception as e: + if os.path.exists(temp_output): + os.remove(temp_output) + logger.error(f"merge_video_audio failed with error: {e}") From 2f09d10930f1fb8aebf0d00d052e6b2008c5369a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 17 Sep 2025 10:23:30 +0300 Subject: [PATCH 096/131] style --- src/diffusers/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index bc027d246227..fd3c5807b002 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -41,7 +41,7 @@ from .deprecation_utils import deprecate from .doc_utils import replace_example_docstring from .dynamic_modules_utils import get_class_from_dynamic_module -from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video, export_to_merged_video_audio +from .export_utils import export_to_gif, export_to_merged_video_audio, export_to_obj, export_to_ply, export_to_video from .hub_utils import ( PushToHubMixin, _add_variant, From d837dfc4f8736fdee092824ab165ccc1ca268254 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 17 Sep 2025 14:15:39 +0300 Subject: [PATCH 097/131] Refactor audio injection logic Consolidates audio injection functionality by moving the `after_transformer_block` method into the `AudioInjector` class. This change improves code organization and encapsulation, making the injection process more modular and maintainable. 
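A minimal, self-contained sketch of that injection pattern, assuming toy shapes and a plain `nn.MultiheadAttention` in place of the pipeline's `WanAttention`: after each configured transformer block, the per-frame hidden states cross-attend to that frame's audio tokens and the result is added back as a residual.

```python
import torch
import torch.nn as nn

hidden_dim, num_frames, tokens_per_frame, audio_tokens = 64, 4, 8, 5
injected_block_id = {0: 0, 3: 1}  # transformer block index -> injector index (illustrative)
injectors = nn.ModuleList(
    [nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True) for _ in injected_block_id]
)

hidden = torch.randn(1, num_frames * tokens_per_frame, hidden_dim)  # (B, F*T, C) flattened video tokens
audio = torch.randn(1, num_frames, audio_tokens, hidden_dim)        # (B, F, A, C) per-frame audio tokens

for block_idx in range(4):
    # ... the regular transformer block would run here ...
    if block_idx in injected_block_id:
        inj = injectors[injected_block_id[block_idx]]
        q = hidden.unflatten(1, (num_frames, tokens_per_frame)).flatten(0, 1)  # (B*F, T, C)
        kv = audio.flatten(0, 1)                                               # (B*F, A, C)
        residual, _ = inj(q, kv, kv)                                           # audio cross-attention
        hidden = hidden + residual.unflatten(0, (1, num_frames)).flatten(1, 2)  # residual injection
print(hidden.shape)  # torch.Size([1, 32, 64])
```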
--- .../transformers/transformer_wan_s2v.py | 117 +++++++++--------- 1 file changed, 59 insertions(+), 58 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index e17f935bd132..2f2429513fc5 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -340,13 +340,13 @@ def __init__( ): super().__init__() self.injected_block_id = {} - audio_injector_id = 0 + num_injection_layers = 0 for mod_name, mod in zip(all_modules_names, all_modules): if isinstance(mod, WanS2VTransformerBlock): for inject_id in inject_layer: if f"transformer_blocks.{inject_id}" in mod_name: - self.injected_block_id[inject_id] = audio_injector_id - audio_injector_id += 1 + self.injected_block_id[inject_id] = num_injection_layers + num_injection_layers += 1 # Cross-attention self.injector = nn.ModuleList( @@ -365,21 +365,55 @@ def __init__( ) self.injector_pre_norm_feat = nn.ModuleList( - [nn.LayerNorm(dim, elementwise_affine=False, eps=eps) for _ in range(audio_injector_id)] + [nn.LayerNorm(dim, elementwise_affine=False, eps=eps) for _ in range(num_injection_layers)] ) self.injector_pre_norm_vec = nn.ModuleList( - [nn.LayerNorm(dim, elementwise_affine=False, eps=eps) for _ in range(audio_injector_id)] + [nn.LayerNorm(dim, elementwise_affine=False, eps=eps) for _ in range(num_injection_layers)] ) if enable_adain: self.injector_adain_layers = nn.ModuleList( - [AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(audio_injector_id)] + [AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(num_injection_layers)] ) if need_adain_ont: self.injector_adain_output_layers = nn.ModuleList( - [nn.Linear(dim, dim) for _ in range(audio_injector_id)] + [nn.Linear(dim, dim) for _ in range(num_injection_layers)] ) + def forward( + self, + block_idx, + hidden_states, + original_sequence_length, + merged_audio_emb_num_frames, + attn_audio_emb, + audio_emb_global, + ): + audio_attn_id = self.injected_block_id[block_idx] + + input_hidden_states = hidden_states[:, :original_sequence_length].clone() # B (F H W) C + input_hidden_states = input_hidden_states.unflatten(1, (merged_audio_emb_num_frames, -1)).flatten(0, 1) + + if self.config.enable_adain and self.config.adain_mode == "attn_norm": + attn_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id]( + input_hidden_states, temb=audio_emb_global[:, 0] + ) + else: + attn_hidden_states = self.audio_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states) + + attention_kwargs = { + "max_seqlen_k": torch.ones( + attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device + ) + * attn_audio_emb.shape[1] + } + residual_out = self.audio_injector.injector[audio_attn_id]( + attn_hidden_states, attn_audio_emb, None, None, **attention_kwargs + ) + residual_out = residual_out.unflatten(0, (-1, merged_audio_emb_num_frames)).flatten(1, 2) + hidden_states[:, :original_sequence_length] = hidden_states[:, :original_sequence_length] + residual_out + + return hidden_states class FramePackMotioner(nn.Module): def __init__( @@ -991,41 +1025,6 @@ def inject_motion( ) return hidden_states, seq_lens, rope_embs, mask_input - def after_transformer_block( - self, - block_idx, - hidden_states, - original_sequence_length, - merged_audio_emb_num_frames, - attn_audio_emb, - audio_emb_global, - ): - if block_idx in self.audio_injector.injected_block_id.keys(): - 
audio_attn_id = self.audio_injector.injected_block_id[block_idx] - - input_hidden_states = hidden_states[:, :original_sequence_length].clone() # B (F H W) C - input_hidden_states = input_hidden_states.unflatten(1, (merged_audio_emb_num_frames, -1)).flatten(0, 1) - - if self.config.enable_adain and self.config.adain_mode == "attn_norm": - attn_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id]( - input_hidden_states, temb=audio_emb_global[:, 0] - ) - else: - attn_hidden_states = self.audio_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states) - - attention_kwargs = { - "max_seqlen_k": torch.ones( - attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device - ) - * attn_audio_emb.shape[1] - } - residual_out = self.audio_injector.injector[audio_attn_id]( - attn_hidden_states, attn_audio_emb, None, None, **attention_kwargs - ) - residual_out = residual_out.unflatten(0, (-1, merged_audio_emb_num_frames)).flatten(1, 2) - hidden_states[:, :original_sequence_length] = hidden_states[:, :original_sequence_length] + residual_out - - return hidden_states def forward( self, @@ -1163,27 +1162,29 @@ def forward( rotary_emb, attention_kwargs, ) - hidden_states = self.after_transformer_block( - block_idx, - hidden_states, - original_sequence_length, - merged_audio_emb_num_frames, - attn_audio_emb, - audio_emb_global, - ) + if block_idx in self.injected_block_id.keys(): + hidden_states = self.audio_injector( + block_idx, + hidden_states, + original_sequence_length, + merged_audio_emb_num_frames, + attn_audio_emb, + audio_emb_global, + ) else: for block_idx, block in enumerate(self.blocks): hidden_states = block( hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_kwargs ) - hidden_states = self.after_transformer_block( - block_idx, - hidden_states, - original_sequence_length, - merged_audio_emb_num_frames, - attn_audio_emb, - audio_emb_global, - ) + if block_idx in self.injected_block_id.keys(): + hidden_states = self.audio_injector( + block_idx, + hidden_states, + original_sequence_length, + merged_audio_emb_num_frames, + attn_audio_emb, + audio_emb_global, + ) hidden_states = hidden_states[:, :original_sequence_length] From d9fd75586157fbb25c5bb43e9d22962b534d9477 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 17 Sep 2025 14:16:33 +0300 Subject: [PATCH 098/131] style --- src/diffusers/models/transformers/transformer_wan_s2v.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 2f2429513fc5..bf331ad1c835 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -402,9 +402,7 @@ def forward( attn_hidden_states = self.audio_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states) attention_kwargs = { - "max_seqlen_k": torch.ones( - attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device - ) + "max_seqlen_k": torch.ones(attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device) * attn_audio_emb.shape[1] } residual_out = self.audio_injector.injector[audio_attn_id]( @@ -415,6 +413,7 @@ def forward( return hidden_states + class FramePackMotioner(nn.Module): def __init__( self, @@ -1025,7 +1024,6 @@ def inject_motion( ) return hidden_states, seq_lens, rope_embs, mask_input - def forward( self, hidden_states: torch.Tensor, @@ -1193,8 
+1191,7 @@ def forward( # Move the shift and scale tensors to the same device as hidden_states. # When using multi-GPU inference via accelerate these will be on the - # first device rather than the last device, which hidden_states ends up - # on. + # first device rather than the last device, which hidden_states ends up on. shift = shift.to(hidden_states.device) scale = scale.to(hidden_states.device) From e5ab1dded26cf5543f997410c684ffbe379ba785 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:44:06 +0300 Subject: [PATCH 099/131] Update src/diffusers/models/transformers/transformer_wan_s2v.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_wan_s2v.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index bf331ad1c835..576c7dc9203b 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -331,7 +331,6 @@ def __init__( all_modules_names, dim=2048, num_heads=32, - inject_layer=[0, 27], enable_adain=False, adain_dim=2048, need_adain_ont=False, From c6e8fa4d5b7a16f449112c6d9b3d6eee5ab17872 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:44:58 +0300 Subject: [PATCH 100/131] Update src/diffusers/models/transformers/transformer_wan_s2v.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_wan_s2v.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 576c7dc9203b..ecadd5676d61 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -327,8 +327,7 @@ def forward(self, features): class AudioInjector(nn.Module): def __init__( self, - all_modules, - all_modules_names, + num_injection_layers dim=2048, num_heads=32, enable_adain=False, From 62dc61e9851123ede333535ec47a5db5ba353799 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 17 Sep 2025 16:00:38 +0300 Subject: [PATCH 101/131] Refactors audio injection logic Simplifies the audio injection process by directly passing injection layer indices to the `AudioInjector`. This removes the need for a depth-first search and dictionary creation within the injector, making the code more efficient and readable. 
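For context, a minimal sketch of the AdaLayerNorm-style conditioning that the `enable_adain` / `adain_dim` arguments control, with assumed toy dimensions: a global audio embedding is projected to a (shift, scale) pair that modulates the normalized hidden states before the audio cross-attention.

```python
import torch
import torch.nn as nn


class TinyAdaLayerNorm(nn.Module):
    # Illustrative only; the shared implementation lives in diffusers.models.normalization.
    def __init__(self, embedding_dim: int, hidden_dim: int):
        super().__init__()
        self.proj = nn.Sequential(nn.SiLU(), nn.Linear(embedding_dim, 2 * hidden_dim))
        self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        shift, scale = self.proj(temb).chunk(2, dim=-1)  # each (B, hidden_dim)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


x = torch.randn(2, 16, 64)          # (batch, tokens, channels)
audio_global = torch.randn(2, 128)  # assumed global audio embedding
print(TinyAdaLayerNorm(128, 64)(x, audio_global).shape)  # torch.Size([2, 16, 64])
```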
--- .../transformers/transformer_wan_s2v.py | 40 ++++--------------- 1 file changed, 7 insertions(+), 33 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index ecadd5676d61..d51b87a3bc76 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -39,23 +39,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def torch_dfs(model: nn.Module, parent_name="root"): - module_names, modules = [], [] - current_name = parent_name if parent_name else "root" - module_names.append(current_name) - modules.append(model) - - for name, child in model.named_children(): - if parent_name: - child_name = f"{parent_name}.{name}" - else: - child_name = name - child_modules, child_names = torch_dfs(child, child_name) - module_names += child_names - modules += child_modules - return modules, module_names - - def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): # encoder_hidden_states is only passed for cross-attention if encoder_hidden_states is None: @@ -323,11 +306,11 @@ def forward(self, features): return res # b f n dim - class AudioInjector(nn.Module): def __init__( self, - num_injection_layers + num_injection_layers, + inject_layers, dim=2048, num_heads=32, enable_adain=False, @@ -337,14 +320,7 @@ def __init__( added_kv_proj_dim=None, ): super().__init__() - self.injected_block_id = {} - num_injection_layers = 0 - for mod_name, mod in zip(all_modules_names, all_modules): - if isinstance(mod, WanS2VTransformerBlock): - for inject_id in inject_layer: - if f"transformer_blocks.{inject_id}" in mod_name: - self.injected_block_id[inject_id] = num_injection_layers - num_injection_layers += 1 + self.injected_block_id = {inject_id: idx for inject_id, idx in zip(inject_layers, range(num_injection_layers))} # Cross-attention self.injector = nn.ModuleList( @@ -928,13 +904,11 @@ def __init__( ) # 4. 
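Roughly, the change replaces a module-tree walk with a direct mapping from the configured layer ids to injector indices; a small sketch using the default `audio_inject_layers` from this file:

```python
audio_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39]

# After the refactor, the mapping is built directly from the configured ids:
injected_block_id = dict(zip(audio_inject_layers, range(len(audio_inject_layers))))

assert injected_block_id[27] == 7  # block 27 uses the 8th cross-attention injector
```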
Audio Injector - all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks") self.audio_injector = AudioInjector( - all_modules, - all_modules_names, + num_injection_layers=len(audio_inject_layers), + inject_layers=audio_inject_layers, dim=inner_dim, num_heads=num_attention_heads, - inject_layer=audio_inject_layers, enable_adain=enable_adain, adain_dim=inner_dim, need_adain_ont=adain_mode != "attn_norm", @@ -1158,7 +1132,7 @@ def forward( rotary_emb, attention_kwargs, ) - if block_idx in self.injected_block_id.keys(): + if block_idx in self.audio_injector.injected_block_id.keys(): hidden_states = self.audio_injector( block_idx, hidden_states, @@ -1172,7 +1146,7 @@ def forward( hidden_states = block( hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_kwargs ) - if block_idx in self.injected_block_id.keys(): + if block_idx in self.audio_injector.injected_block_id.keys(): hidden_states = self.audio_injector( block_idx, hidden_states, From 6b98ebd3f91be7a6dfbaa766f42bc76194b615b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 17 Sep 2025 16:08:00 +0300 Subject: [PATCH 102/131] Refactor adain mode handling --- .../models/transformers/transformer_wan_s2v.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index d51b87a3bc76..c6c2903713fd 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -314,12 +314,14 @@ def __init__( dim=2048, num_heads=32, enable_adain=False, + adain_mode="attn_norm", adain_dim=2048, - need_adain_ont=False, eps=1e-6, added_kv_proj_dim=None, ): super().__init__() + self.enable_adain = enable_adain + self.adain_mode = adain_mode self.injected_block_id = {inject_id: idx for inject_id, idx in zip(inject_layers, range(num_injection_layers))} # Cross-attention @@ -349,7 +351,7 @@ def __init__( self.injector_adain_layers = nn.ModuleList( [AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(num_injection_layers)] ) - if need_adain_ont: + if adain_mode != "attn_norm": self.injector_adain_output_layers = nn.ModuleList( [nn.Linear(dim, dim) for _ in range(num_injection_layers)] ) @@ -368,18 +370,18 @@ def forward( input_hidden_states = hidden_states[:, :original_sequence_length].clone() # B (F H W) C input_hidden_states = input_hidden_states.unflatten(1, (merged_audio_emb_num_frames, -1)).flatten(0, 1) - if self.config.enable_adain and self.config.adain_mode == "attn_norm": - attn_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id]( + if self.enable_adain and self.adain_mode == "attn_norm": + attn_hidden_states = self.injector_adain_layers[audio_attn_id]( input_hidden_states, temb=audio_emb_global[:, 0] ) else: - attn_hidden_states = self.audio_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states) + attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states) attention_kwargs = { "max_seqlen_k": torch.ones(attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device) * attn_audio_emb.shape[1] } - residual_out = self.audio_injector.injector[audio_attn_id]( + residual_out = self.injector[audio_attn_id]( attn_hidden_states, attn_audio_emb, None, None, **attention_kwargs ) residual_out = residual_out.unflatten(0, (-1, merged_audio_emb_num_frames)).flatten(1, 2) @@ -911,7 +913,7 
@@ def __init__( num_heads=num_attention_heads, enable_adain=enable_adain, adain_dim=inner_dim, - need_adain_ont=adain_mode != "attn_norm", + adain_mode=adain_mode, eps=eps, ) From 8665fd573ada0e2903bcde17f3708b46f06d6970 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 17 Sep 2025 18:53:49 +0300 Subject: [PATCH 103/131] style --- src/diffusers/models/transformers/transformer_wan_s2v.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index c6c2903713fd..2f67eb76d0ca 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -306,6 +306,7 @@ def forward(self, features): return res # b f n dim + class AudioInjector(nn.Module): def __init__( self, @@ -322,7 +323,7 @@ def __init__( super().__init__() self.enable_adain = enable_adain self.adain_mode = adain_mode - self.injected_block_id = {inject_id: idx for inject_id, idx in zip(inject_layers, range(num_injection_layers))} + self.injected_block_id = dict(zip(inject_layers, range(num_injection_layers))) # Cross-attention self.injector = nn.ModuleList( @@ -381,9 +382,7 @@ def forward( "max_seqlen_k": torch.ones(attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device) * attn_audio_emb.shape[1] } - residual_out = self.injector[audio_attn_id]( - attn_hidden_states, attn_audio_emb, None, None, **attention_kwargs - ) + residual_out = self.injector[audio_attn_id](attn_hidden_states, attn_audio_emb, None, None, **attention_kwargs) residual_out = residual_out.unflatten(0, (-1, merged_audio_emb_num_frames)).flatten(1, 2) hidden_states[:, :original_sequence_length] = hidden_states[:, :original_sequence_length] + residual_out From dd15817f278690911aa6f93d45a8ef514870b54c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 18 Sep 2025 08:16:23 +0300 Subject: [PATCH 104/131] revert --- src/diffusers/models/transformers/transformer_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index e2c0d3e508de..968a0369c243 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -206,7 +206,7 @@ def __init__( if added_kv_proj_dim is not None: self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) - self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) self.is_cross_attention = cross_attention_dim_head is not None From 9f4edb4821581b3b586566ede994a5ecb300332b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 18 Sep 2025 08:24:58 +0300 Subject: [PATCH 105/131] Take `AdaLayerNorm` from `normalization` --- .../transformers/transformer_wan_s2v.py | 41 +------------------ 1 file changed, 2 insertions(+), 39 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 2f67eb76d0ca..10f3e7f00939 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -30,7 +30,7 @@ from ..embeddings import 
PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin, get_parameter_dtype -from ..normalization import FP32LayerNorm +from ..normalization import FP32LayerNorm, AdaLayerNorm from .transformer_wan import ( WanAttention, ) @@ -166,43 +166,6 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): return hidden_states -class AdaLayerNorm(nn.Module): - r""" - Norm layer modified to incorporate timestep embeddings. - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - output_dim (`int`, *optional*): Output dimension for the layer. - norm_elementwise_affine (`bool`, defaults to `False`): Whether to use elementwise affine in LayerNorm. - norm_eps (`float`, defaults to `1e-5`): Epsilon value for LayerNorm. - """ - - def __init__( - self, - embedding_dim: int, - output_dim: Optional[int] = None, - norm_elementwise_affine: bool = False, - norm_eps: float = 1e-5, - ): - super().__init__() - - output_dim = output_dim or embedding_dim * 2 - - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, output_dim) - self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) - - def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: - temb = self.linear(self.silu(temb)) - - shift, scale = temb.chunk(2, dim=1) - shift = shift[:, None, :] - scale = scale[:, None, :] - - x = self.norm(x) * (1 + scale) + shift - return x - - class WanS2VCausalConv1d(nn.Module): def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): super().__init__() @@ -350,7 +313,7 @@ def __init__( if enable_adain: self.injector_adain_layers = nn.ModuleList( - [AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(num_injection_layers)] + [AdaLayerNorm(embedding_dim=adain_dim, output_dim=dim * 2, chunk_dim=1) for _ in range(num_injection_layers)] ) if adain_mode != "attn_norm": self.injector_adain_output_layers = nn.ModuleList( From 52ffc494f3522a47ee1d37865dd98df30c98d33a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 18 Sep 2025 08:25:59 +0300 Subject: [PATCH 106/131] style --- src/diffusers/models/transformers/transformer_wan_s2v.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 10f3e7f00939..436a87b6cee5 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -30,7 +30,7 @@ from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin, get_parameter_dtype -from ..normalization import FP32LayerNorm, AdaLayerNorm +from ..normalization import AdaLayerNorm, FP32LayerNorm from .transformer_wan import ( WanAttention, ) @@ -313,7 +313,10 @@ def __init__( if enable_adain: self.injector_adain_layers = nn.ModuleList( - [AdaLayerNorm(embedding_dim=adain_dim, output_dim=dim * 2, chunk_dim=1) for _ in range(num_injection_layers)] + [ + AdaLayerNorm(embedding_dim=adain_dim, output_dim=dim * 2, chunk_dim=1) + for _ in range(num_injection_layers) + ] ) if adain_mode != "attn_norm": self.injector_adain_output_layers = nn.ModuleList( From 6196332028531e297edb1c47aeddedcd28f9a247 
Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 18 Sep 2025 11:08:37 +0300 Subject: [PATCH 107/131] Refactor audio encoder with weighted average layer --- .../transformers/transformer_wan_s2v.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 436a87b6cee5..6847e52e7ef5 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -248,22 +248,32 @@ def forward(self, x): return x, x_local +class WeightedAveragelayer(nn.Module): + def __init__(self, num_layers): + super().__init__() + self.weights = torch.nn.Parameter(torch.ones((1, num_layers, 1, 1)) * 0.01) + self.act = torch.nn.SiLU() + + def forward(self, features): + # features B * num_layers * dim * video_length + weights = self.act(self.weights) + weights_sum = weights.sum(dim=1, keepdims=True) + weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f + + return weighted_feat + + class CausalAudioEncoder(nn.Module): def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_audio_token=4, need_global=False): super().__init__() + self.weighted_avg = WeightedAveragelayer(num_layers) self.encoder = WanS2VMotionEncoder( in_dim=dim, hidden_dim=out_dim, num_attention_heads=num_audio_token, need_global=need_global ) - weight = torch.ones((1, num_layers, 1, 1)) * 0.01 - - self.weights = torch.nn.Parameter(weight) - self.act = torch.nn.SiLU() def forward(self, features): # features B * num_layers * dim * video_length - weights = self.act(self.weights) - weights_sum = weights.sum(dim=1, keepdims=True) - weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f + weighted_feat = self.weighted_avg(features) weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim res = self.encoder(weighted_feat) # b f n dim From 5c50519f7f58eb0686ece7c2921d67b58b3b2177 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 18 Sep 2025 11:12:00 +0300 Subject: [PATCH 108/131] style --- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 6847e52e7ef5..1379fff3cc9f 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -259,7 +259,7 @@ def forward(self, features): weights = self.act(self.weights) weights_sum = weights.sum(dim=1, keepdims=True) weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f - + return weighted_feat From ee1f6fffaee585622e2483183e99e76acf9c90e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 24 Sep 2025 13:59:39 +0300 Subject: [PATCH 109/131] Enhance image resizing functionality with additional options for resize and crop strategies --- src/diffusers/image_processor.py | 84 ++++++++++++++++++++++---------- 1 file changed, 58 insertions(+), 26 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 0e3082eada8a..fc444f2c02c2 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -437,10 +437,11 @@ def _resize_and_crop( image: PIL.Image.Image, width: int, height: int, + resize_type: str = "fit_within", + crop_type: str = "paste_center", ) -> 
PIL.Image.Image: r""" - Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center - the image within the dimensions, cropping the excess. + Resize and crop the image using different strategies. Args: image (`PIL.Image.Image`): @@ -449,28 +450,53 @@ def _resize_and_crop( The width to resize the image to. height (`int`): The height to resize the image to. + resize_type (`str`, optional): + How to resize the image. Options: + - "fit_within": Resize to fit within dimensions, maintaining aspect ratio (default) + - "min_dimension": Resize so smaller dimension becomes min(width, height) + crop_type (`str`, optional): + How to handle the final cropping/positioning. Options: + - "paste_center": Paste resized image on centered canvas, pad with black (default) + - "center_crop": Center crop to exact dimensions, pad with black if needed Returns: `PIL.Image.Image`: The resized and cropped image. """ - ratio = width / height - src_ratio = image.width / image.height - src_w = width if ratio > src_ratio else image.width * height // image.height - src_h = height if ratio <= src_ratio else image.height * width // image.width + if resize_type == "fit_within": + # Resize to fit within dimensions + ratio = width / height + src_ratio = image.width / image.height - resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"]) - res = Image.new("RGB", (width, height)) - res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) - return res + src_w = width if ratio > src_ratio else image.width * height // image.height + src_h = height if ratio <= src_ratio else image.height * width // image.width + + resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample]) + elif resize_type == "min_dimension": + # # Resize so smaller dimension becomes min(width, height) + from torchvision.transforms import Resize + resized = Resize(min(height, width))(image) + else: + raise ValueError(f"Unknown resize_type: {resize_type}") + + if crop_type == "paste_center": + # Paste on canvas, center position + res = Image.new("RGB", (width, height), color=0) # Black background + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + return res + elif crop_type == "center_crop": + from torchvision.transforms import CenterCrop + return CenterCrop((height, width))(resized) + else: + raise ValueError(f"Unknown crop_type: {crop_type}") def resize( self, image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], height: int, width: int, - resize_mode: str = "default", # "default", "fill", "crop" + resize_mode: str = "default", # "default", "fill", "crop", "resize_min_center_crop" ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: """ Resize image. @@ -483,13 +509,15 @@ def resize( width (`int`): The width to resize to. resize_mode (`str`, *optional*, defaults to `default`): - The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit - within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, - will resize the image to fit within the specified width and height, maintaining the aspect ratio, and - then center the image within the dimensions, filling empty with data from image. If `crop`, will resize - the image to fit within the specified width and height, maintaining the aspect ratio, and then center - the image within the dimensions, cropping the excess. 
Note that resize_mode `fill` and `crop` are only - supported for PIL image input. + The resize mode to use, can be one of `default`, `fill`, `crop`, or `resize_min_center_crop`. If `default`, will + resize the image to fit within the specified width and height, and it may not maintaining the original + aspect ratio. If `fill`, will resize the image to fit within the specified width and height, maintaining + the aspect ratio, and then center the image within the dimensions, filling empty with data from image. + If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect + ratio, and then center the image within the dimensions, cropping the excess. If `resize_min_center_crop`, will + resize the image so that the smaller dimension becomes min(width, height), then center crop to exact + target dimensions (matches Wan2.2-S2V preprocessing). Note that resize_mode `fill`, `crop`, and `resize_min_center_crop` + are only supported for PIL image input. Returns: `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: @@ -508,6 +536,8 @@ def resize( image = self._resize_and_fill(image, width, height) elif resize_mode == "crop": image = self._resize_and_crop(image, width, height) + elif resize_mode == "resize_min_center_crop": + image = self._resize_and_crop(image, width, height, resize_type="min_dimension", crop_type="center_crop") else: raise ValueError(f"resize_mode {resize_mode} is not supported") @@ -615,7 +645,7 @@ def preprocess( image: PipelineImageInput, height: Optional[int] = None, width: Optional[int] = None, - resize_mode: str = "default", # "default", "fill", "crop" + resize_mode: str = "default", # "default", "fill", "crop", "resize_min_center_crop" crops_coords: Optional[Tuple[int, int, int, int]] = None, ) -> torch.Tensor: """ @@ -631,13 +661,15 @@ def preprocess( width (`int`, *optional*): The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. resize_mode (`str`, *optional*, defaults to `default`): - The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within - the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will - resize the image to fit within the specified width and height, maintaining the aspect ratio, and then - center the image within the dimensions, filling empty with data from image. If `crop`, will resize the - image to fit within the specified width and height, maintaining the aspect ratio, and then center the - image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only - supported for PIL image input. + The resize mode, can be one of `default`, `fill`, `crop`, or `resize_min_center_crop`. If `default`, will resize + the image to fit within the specified width and height, and it may not maintaining the original aspect + ratio. If `fill`, will resize the image to fit within the specified width and height, maintaining the + aspect ratio, and then center the image within the dimensions, filling empty with data from image. If + `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, + and then center the image within the dimensions, cropping the excess. If `resize_min_center_crop`, will resize the + image so that the smaller dimension becomes min(width, height), then center crop to exact target + dimensions (matches Wan2.2 preprocessing). 
Note that resize_mode `fill`, `crop`, and `resize_min_center_crop` are + only supported for PIL image input. crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): The crop coordinates for each image in the batch. If `None`, will not crop the image. From e15d3f6c38f38c6e992842b0ec2fc76e4377319f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 24 Sep 2025 14:00:26 +0300 Subject: [PATCH 110/131] Add resize_mode parameter to preprocess_video for flexible video resizing options --- src/diffusers/video_processor.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index 59b59b47d2c7..d97f6bfe4b61 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -25,7 +25,7 @@ class VideoProcessor(VaeImageProcessor): r"""Simple video processor.""" - def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor: + def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None, resize_mode: str = "default") -> torch.Tensor: r""" Preprocesses input video(s). @@ -49,6 +49,9 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[ width (`int`, *optional*`, defaults to `None`): The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get the default width. + resize_mode (`str`, *optional*, defaults to `default`): + The resize mode, can be one of `default`, `fill`, `crop`, or `center_crop`. See `VaeImageProcessor.preprocess` + for detailed descriptions of each mode. """ if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: warnings.warn( @@ -79,7 +82,7 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[ "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" ) - video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) + video = torch.stack([self.preprocess(img, height=height, width=width, resize_mode=resize_mode) for img in video], dim=0) # move the number of channels before the number of frames. video = video.permute(0, 2, 1, 3, 4) From bc2165acce7209da0a91b1c72a3a3576a691d214 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 24 Sep 2025 14:08:31 +0300 Subject: [PATCH 111/131] Refactor video processing in WanSpeechToVideoPipeline to support bilinear resampling and adjust frame chunk settings Updates the speech-to-video pipeline to perform a decode-encode cycle within the generation loop for each video chunk. This change improves temporal consistency between chunks by using the pixels of the previously generated frames, rather than their latents, to condition the next chunk. Key changes include: - Modifying the generation loop to decode latents into video frames, update the conditioning pixels, and then re-encode them for the next iteration's motion latents. - Setting the default `num_frames_per_chunk` to 80 and adjusting the corresponding frame logic. - Enabling `bilinear` resampling in the `VideoProcessor`. 
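To make the chunk-to-chunk conditioning concrete, here is a minimal sketch of the decode-and-re-encode cycle described above. It is illustrative only: `denoise_chunk`, `decode_to_pixels`, and `encode_to_latents` are hypothetical stand-ins for the pipeline's denoising loop and its VAE decode/encode (plus latent normalization) steps, and the tensor layout is assumed to be `(B, C, F, H, W)`.

```python
# Illustrative sketch (not the pipeline implementation): generate the video in
# chunks, decode each chunk to pixels, and re-encode the trailing pixels so the
# next chunk is conditioned on decoded frames rather than on raw latents.
import torch


def generate_chunks(denoise_chunk, decode_to_pixels, encode_to_latents,
                    num_chunks, motion_frames=73):
    video_chunks = []
    last_pixels = None      # pixels of the most recently generated frames
    motion_latents = None   # motion conditioning for the next chunk
    for r in range(num_chunks):
        latents = denoise_chunk(r, motion_latents)   # hypothetical denoising loop
        pixels = decode_to_pixels(latents)           # latents -> frames, (B, C, F, H, W)
        if last_pixels is None:
            last_pixels = pixels[:, :, -motion_frames:]
        else:
            k = min(motion_frames, pixels.shape[2])
            # slide the window of conditioning frames forward
            last_pixels = torch.cat([last_pixels[:, :, k:], pixels[:, :, -k:]], dim=2)
        # re-encode *pixels* (not latents) so the next chunk sees decoded frames
        motion_latents = encode_to_latents(last_pixels)
        video_chunks.append(pixels)
    return torch.cat(video_chunks, dim=2)
```

Conditioning on decoded pixels costs one extra VAE round trip per chunk; that is the trade-off this change accepts in exchange for better temporal consistency between chunks.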
--- .../pipelines/wan/pipeline_wan_s2v.py | 86 +++++++++---------- 1 file changed, 40 insertions(+), 46 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 512c94a552d4..ff1323ab47ac 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -288,7 +288,7 @@ def __init__( self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, resample="bilinear") self.audio_processor = audio_processor self.motion_frames = 73 self.drop_first_motion = True @@ -352,7 +352,7 @@ def encode_audio( res = self.audio_encoder(input_values.to(self.audio_encoder.device), output_hidden_states=True) feat = torch.cat(res.hidden_states) - feat = linear_interpolation(feat, input_fps=50, output_fps=30) + feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate) audio_embed = feat.to(torch.float32) # Encoding for the motion @@ -568,7 +568,7 @@ def prepare_latents( num_channels_latents: int = 16, height: int = 480, width: int = 832, - num_frames_per_chunk: int = 81, + num_frames_per_chunk: int = 80, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -627,14 +627,14 @@ def prepare_latents( pose_video, num_chunks, num_frames_per_chunk, height, width, latents_mean, latents_std ) # Encode motion latents + videos_last_pixels = motion_pixels.detach() if init_first_frame: self.drop_first_motion = False - motion_pixels[:, :, -6:] = latent_condition + motion_pixels[:, :, -6:] = video_condition motion_latents = retrieve_latents(self.vae.encode(motion_pixels), sample_mode="argmax") motion_latents = (motion_latents - latents_mean) * latents_std - videos_last_latents = motion_latents.detach() - return latents, latent_condition, videos_last_latents, motion_latents, pose_condition + return latents, latent_condition, videos_last_pixels, motion_latents, pose_condition else: return latents @@ -706,7 +706,7 @@ def __call__( pose_video_path_or_url: Optional[str] = None, height: int = 480, width: int = 832, - num_frames_per_chunk: int = 81, + num_frames_per_chunk: int = 80, num_inference_steps: int = 40, guidance_scale: float = 4.5, num_videos_per_prompt: Optional[int] = 1, @@ -751,9 +751,8 @@ def __call__( The height of the generated video. width (`int`, defaults to `832`): The width of the generated video. - num_frames_per_chunk (`int`, defaults to `81`): - The number of frames in each chunk of the generated video. `num_frames_per_chunk` - 1 should be a - multiple of 4. + num_frames_per_chunk (`int`, defaults to `80`): + The number of frames in each chunk of the generated video. `num_frames_per_chunk` should be a multiple of 4. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -839,12 +838,12 @@ def __call__( audio_embeds, ) - if num_frames_per_chunk % self.vae_scale_factor_temporal != 1: - logger.warning( - f"`num_frames_per_chunk - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." 
- ) + if num_frames_per_chunk % self.vae_scale_factor_temporal != 0: num_frames_per_chunk = ( - num_frames_per_chunk // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames_per_chunk // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + ) + logger.warning( + f"`num_frames_per_chunk` had to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number: {num_frames_per_chunk}" ) num_frames_per_chunk = max(num_frames_per_chunk, 1) @@ -902,11 +901,11 @@ def __call__( target_fps=sampling_fps, reverse=True, ) - pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width).to( + pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width, resize_mode="resize_min_center_crop").to( device, dtype=torch.float32 ) - all_latents = [] + video_chunks = [] for r in range(num_chunks): latents_outputs = self.prepare_latents( image if r == 0 else None, @@ -926,7 +925,7 @@ def __call__( ) if r == 0: - latents, condition, videos_last_latents, motion_latents, pose_condition = latents_outputs + latents, condition, videos_last_pixels, motion_latents, pose_condition = latents_outputs else: latents = latents_outputs @@ -1013,45 +1012,40 @@ def __call__( else: decode_latents = torch.cat([condition, latents], dim=2) - # Work in latent space - no decode-encode cycle - num_latent_frames = (num_frames_per_chunk + 3) // self.vae_scale_factor_temporal - segment_latents = decode_latents[:, :, -num_latent_frames:] - if self.drop_first_motion and r == 0: - segment_latents = segment_latents[:, :, (3 + 3) // self.vae_scale_factor_temporal :] + decode_latents = decode_latents.to(self.vae.dtype) - num_latent_overlap_frames = min(latent_motion_frames, segment_latents.shape[2]) - videos_last_latents = torch.cat( - [ - videos_last_latents[:, :, num_latent_overlap_frames:], - segment_latents[:, :, -num_latent_overlap_frames:], - ], - dim=2, + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(decode_latents.device, decode_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + decode_latents.device, decode_latents.dtype ) + decode_latents = decode_latents / latents_std + latents_mean + video = self.vae.decode(decode_latents, return_dict=False)[0] + video = video[:, :, -(num_frames_per_chunk):] + + if self.drop_first_motion and r == 0: + video = video[:, :, 3:] + + num_overlap_frames = min(self.motion_frames, video.shape[2]) + videos_last_pixels = torch.cat([videos_last_pixels[:, :, num_overlap_frames:], video[:, :, -num_overlap_frames:]], dim=2) # Update motion_latents for next iteration - motion_latents = videos_last_latents.to(dtype=motion_latents.dtype, device=motion_latents.device) + motion_latents = retrieve_latents(self.vae.encode(videos_last_pixels), sample_mode="argmax") + motion_latents = (motion_latents - latents_mean) * latents_std - # Accumulate latents so as to decode them all at once at the end - all_latents.append(segment_latents) + video_chunks.append(video) - latents = torch.cat(all_latents, dim=2) + video_chunks = torch.cat(video_chunks, dim=2) self._current_timestep = None if not output_type == "latent": - latents = latents.to(self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, 
self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) - latents = latents / latents_std + latents_mean - video = self.vae.decode(latents, return_dict=False)[0] - video = self.video_processor.postprocess_video(video, output_type=output_type) + video = self.video_processor.postprocess_video(video_chunks, output_type=output_type) else: + # TODO video = latents # Offload all models From 0bf98b665ac05451dcba3eb7ce83fb57f36df13c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 24 Sep 2025 14:09:04 +0300 Subject: [PATCH 112/131] style --- src/diffusers/image_processor.py | 43 +++++++++++-------- .../pipelines/wan/pipeline_wan_s2v.py | 13 +++--- src/diffusers/video_processor.py | 12 ++++-- 3 files changed, 40 insertions(+), 28 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index fc444f2c02c2..086c013458ce 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -476,6 +476,7 @@ def _resize_and_crop( elif resize_type == "min_dimension": # # Resize so smaller dimension becomes min(width, height) from torchvision.transforms import Resize + resized = Resize(min(height, width))(image) else: raise ValueError(f"Unknown resize_type: {resize_type}") @@ -487,6 +488,7 @@ def _resize_and_crop( return res elif crop_type == "center_crop": from torchvision.transforms import CenterCrop + return CenterCrop((height, width))(resized) else: raise ValueError(f"Unknown crop_type: {crop_type}") @@ -509,15 +511,16 @@ def resize( width (`int`): The width to resize to. resize_mode (`str`, *optional*, defaults to `default`): - The resize mode to use, can be one of `default`, `fill`, `crop`, or `resize_min_center_crop`. If `default`, will - resize the image to fit within the specified width and height, and it may not maintaining the original - aspect ratio. If `fill`, will resize the image to fit within the specified width and height, maintaining - the aspect ratio, and then center the image within the dimensions, filling empty with data from image. - If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect - ratio, and then center the image within the dimensions, cropping the excess. If `resize_min_center_crop`, will - resize the image so that the smaller dimension becomes min(width, height), then center crop to exact - target dimensions (matches Wan2.2-S2V preprocessing). Note that resize_mode `fill`, `crop`, and `resize_min_center_crop` - are only supported for PIL image input. + The resize mode to use, can be one of `default`, `fill`, `crop`, or `resize_min_center_crop`. If + `default`, will resize the image to fit within the specified width and height, and it may not + maintaining the original aspect ratio. If `fill`, will resize the image to fit within the specified + width and height, maintaining the aspect ratio, and then center the image within the dimensions, + filling empty with data from image. If `crop`, will resize the image to fit within the specified width + and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the + excess. If `resize_min_center_crop`, will resize the image so that the smaller dimension becomes + min(width, height), then center crop to exact target dimensions (matches Wan2.2-S2V preprocessing). + Note that resize_mode `fill`, `crop`, and `resize_min_center_crop` are only supported for PIL image + input. 
Returns: `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: @@ -537,7 +540,9 @@ def resize( elif resize_mode == "crop": image = self._resize_and_crop(image, width, height) elif resize_mode == "resize_min_center_crop": - image = self._resize_and_crop(image, width, height, resize_type="min_dimension", crop_type="center_crop") + image = self._resize_and_crop( + image, width, height, resize_type="min_dimension", crop_type="center_crop" + ) else: raise ValueError(f"resize_mode {resize_mode} is not supported") @@ -661,15 +666,15 @@ def preprocess( width (`int`, *optional*): The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. resize_mode (`str`, *optional*, defaults to `default`): - The resize mode, can be one of `default`, `fill`, `crop`, or `resize_min_center_crop`. If `default`, will resize - the image to fit within the specified width and height, and it may not maintaining the original aspect - ratio. If `fill`, will resize the image to fit within the specified width and height, maintaining the - aspect ratio, and then center the image within the dimensions, filling empty with data from image. If - `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, - and then center the image within the dimensions, cropping the excess. If `resize_min_center_crop`, will resize the - image so that the smaller dimension becomes min(width, height), then center crop to exact target - dimensions (matches Wan2.2 preprocessing). Note that resize_mode `fill`, `crop`, and `resize_min_center_crop` are - only supported for PIL image input. + The resize mode, can be one of `default`, `fill`, `crop`, or `resize_min_center_crop`. If `default`, + will resize the image to fit within the specified width and height, and it may not maintaining the + original aspect ratio. If `fill`, will resize the image to fit within the specified width and height, + maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data + from image. If `crop`, will resize the image to fit within the specified width and height, maintaining + the aspect ratio, and then center the image within the dimensions, cropping the excess. If + `resize_min_center_crop`, will resize the image so that the smaller dimension becomes min(width, + height), then center crop to exact target dimensions (matches Wan2.2 preprocessing). Note that + resize_mode `fill`, `crop`, and `resize_min_center_crop` are only supported for PIL image input. crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): The crop coordinates for each image in the batch. If `None`, will not crop the image. diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index ff1323ab47ac..315b10e77390 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -752,7 +752,8 @@ def __call__( width (`int`, defaults to `832`): The width of the generated video. num_frames_per_chunk (`int`, defaults to `80`): - The number of frames in each chunk of the generated video. `num_frames_per_chunk` should be a multiple of 4. + The number of frames in each chunk of the generated video. `num_frames_per_chunk` should be a multiple + of 4. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
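As a quick illustration of the `resize_min_center_crop` mode documented in the docstrings above, the sketch below mirrors the torchvision `Resize`/`CenterCrop` combination the processor falls back to; the input size and target dimensions are placeholder values, not part of the patch.

```python
# Sketch of "resize_min_center_crop": scale the image so its smaller side equals
# min(width, height), then center-crop (padding with black if needed) to the
# exact (height, width) target. Sizes below are placeholders.
from PIL import Image
from torchvision.transforms import CenterCrop, Resize


def resize_min_center_crop(image: Image.Image, height: int, width: int) -> Image.Image:
    resized = Resize(min(height, width))(image)    # smaller edge -> min(height, width)
    return CenterCrop((height, width))(resized)    # crop to the exact target size


img = Image.new("RGB", (1280, 720))
out = resize_min_center_crop(img, height=480, width=832)
print(out.size)  # (832, 480)
```

Resizing by the smaller edge keeps the full frame content before cropping; torchvision's `CenterCrop` pads with zeros if a target edge ends up larger than the resized image.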
@@ -901,9 +902,9 @@ def __call__( target_fps=sampling_fps, reverse=True, ) - pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width, resize_mode="resize_min_center_crop").to( - device, dtype=torch.float32 - ) + pose_video = self.video_processor.preprocess_video( + pose_video, height=height, width=width, resize_mode="resize_min_center_crop" + ).to(device, dtype=torch.float32) video_chunks = [] for r in range(num_chunks): @@ -1030,7 +1031,9 @@ def __call__( video = video[:, :, 3:] num_overlap_frames = min(self.motion_frames, video.shape[2]) - videos_last_pixels = torch.cat([videos_last_pixels[:, :, num_overlap_frames:], video[:, :, -num_overlap_frames:]], dim=2) + videos_last_pixels = torch.cat( + [videos_last_pixels[:, :, num_overlap_frames:], video[:, :, -num_overlap_frames:]], dim=2 + ) # Update motion_latents for next iteration motion_latents = retrieve_latents(self.vae.encode(videos_last_pixels), sample_mode="argmax") diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index d97f6bfe4b61..dc6623e1e472 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -25,7 +25,9 @@ class VideoProcessor(VaeImageProcessor): r"""Simple video processor.""" - def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None, resize_mode: str = "default") -> torch.Tensor: + def preprocess_video( + self, video, height: Optional[int] = None, width: Optional[int] = None, resize_mode: str = "default" + ) -> torch.Tensor: r""" Preprocesses input video(s). @@ -50,8 +52,8 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[ The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get the default width. resize_mode (`str`, *optional*, defaults to `default`): - The resize mode, can be one of `default`, `fill`, `crop`, or `center_crop`. See `VaeImageProcessor.preprocess` - for detailed descriptions of each mode. + The resize mode, can be one of `default`, `fill`, `crop`, or `center_crop`. See + `VaeImageProcessor.preprocess` for detailed descriptions of each mode. """ if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: warnings.warn( @@ -82,7 +84,9 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[ "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" ) - video = torch.stack([self.preprocess(img, height=height, width=width, resize_mode=resize_mode) for img in video], dim=0) + video = torch.stack( + [self.preprocess(img, height=height, width=width, resize_mode=resize_mode) for img in video], dim=0 + ) # move the number of channels before the number of frames. 
video = video.permute(0, 2, 1, 3, 4) From 226a4511a7da7e791d13192364127732819c07ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 24 Sep 2025 15:51:19 +0300 Subject: [PATCH 113/131] Add `Motioner` class for _simple_ motion processing in `WanS2VTransformer3DModel` --- scripts/convert_wan_to_diffusers.py | 1 + .../transformers/transformer_wan_s2v.py | 77 ++++++++++++------- 2 files changed, 50 insertions(+), 28 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 8f1f2bb43a79..2b78b36a60dd 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -154,6 +154,7 @@ "cond_encoder.weight": "condition_embedder.pose_embedder.weight", "cond_encoder.bias": "condition_embedder.pose_embedder.bias", "trainable_cond_mask": "trainable_condition_mask", + "patch_embedding": "motion_in.patch_embedding", # Audio injector attention mappings - convert original q/k/v/o format to diffusers format **{ f"audio_injector.injector.{i}.{src}": f"audio_injector.injector.{i}.{dst}" diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 1379fff3cc9f..2db142522067 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -483,6 +483,43 @@ def forward(self, motion_latents, add_last_motion=2): return motion_lat, motion_rope_emb +class Motioner(nn.Module): + def __init__(self, inner_dim, num_attention_heads, patch_size=(1, 2, 2), in_channels=16, rope_max_seq_len=1024): + super().__init__() + self.inner_dim = inner_dim + self.num_attention_heads = num_attention_heads + + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + self.rope = WanS2VRotaryPosEmbed( + inner_dim // num_attention_heads, patch_size, rope_max_seq_len, num_attention_heads + ) + + def forward(self, motion_latents): + latent_motion_frames = motion_latents.shape[2] + mot = self.patch_embedding(motion_latents) + + height, width = mot.shape[3], mot.shape[4] + flat_mot = mot.flatten(2).transpose(1, 2).contiguous() + motion_grid_sizes = [ + [ + torch.tensor([-latent_motion_frames, 0, 0]).unsqueeze(0), + torch.tensor([0, height, width]).unsqueeze(0), + torch.tensor([latent_motion_frames, height, width]).unsqueeze(0), + ] + ] + motion_rope_emb = self.rope( + flat_mot.detach().view( + flat_mot.shape[0], + flat_mot.shape[1], + self.num_attention_heads, + self.inner_dim // self.num_attention_heads, + ), + motion_grid_sizes, + ) + + return flat_mot, motion_rope_emb + + class WanTimeTextAudioPoseEmbedding(nn.Module): def __init__( self, @@ -855,6 +892,14 @@ def __init__( drop_mode=framepack_drop_mode, patch_size=patch_size, ) + else: + self.motion_in = Motioner( + inner_dim=inner_dim, + num_attention_heads=num_attention_heads, + patch_size=patch_size, + in_channels=in_channels, + rope_max_seq_len=rope_max_seq_len, + ) self.trainable_condition_mask = nn.Embedding(3, inner_dim) @@ -900,36 +945,12 @@ def __init__( self.gradient_checkpointing = False def process_motion(self, motion_latents, drop_motion_frames=False): + flattern_mot, mot_remb = self.motion_in(motion_latents) + if drop_motion_frames or motion_latents[0].shape[1] == 0: return [], [] - self.latent_motion_frames = motion_latents[0].shape[1] - mot = [self.patch_embedding(m.unsqueeze(0)) for m in motion_latents] - batch_size = len(mot) - - mot_remb = [] - flattern_mot = [] - for bs in 
range(batch_size): - height, width = mot[bs].shape[3], mot[bs].shape[4] - flat_mot = mot[bs].flatten(2).transpose(1, 2).contiguous() - motion_grid_sizes = [ - [ - torch.tensor([-self.latent_motion_frames, 0, 0]).unsqueeze(0), - torch.tensor([0, height, width]).unsqueeze(0), - torch.tensor([self.latent_motion_frames, height, width]).unsqueeze(0), - ] - ] - motion_rope_emb = self.rope( - flat_mot.detach().view( - 1, - flat_mot.shape[1], - self.config.num_attention_heads, - self.inner_dim // self.config.num_attention_heads, - ), - motion_grid_sizes, - ) - mot_remb.append(motion_rope_emb) - flattern_mot.append(flat_mot) - return flattern_mot, mot_remb + else: + return flattern_mot, mot_remb def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2): flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion) From 7122b619bb4a308ff45aad953999a8a5af82082a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 24 Sep 2025 17:59:30 +0300 Subject: [PATCH 114/131] Add `WanS2VCausalConvLayer` for modularism --- scripts/convert_wan_to_diffusers.py | 5 ++- .../transformers/transformer_wan_s2v.py | 43 ++++++++++--------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 2b78b36a60dd..c64bd6e9c3f5 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -148,8 +148,9 @@ "attn2.to_v_img": "attn2.add_v_proj", "attn2.norm_k_img": "attn2.norm_added_k", # S2V-specific audio component mappings - "casual_audio_encoder.encoder": "condition_embedder.causal_audio_encoder.encoder", - "casual_audio_encoder.weights": "condition_embedder.causal_audio_encoder.weights", + "casual_audio_encoder.encoder.conv2.conv": "condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv", + "casual_audio_encoder.encoder.conv3.conv": "condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv", + "casual_audio_encoder.weights": "condition_embedder.causal_audio_encoder.weighted_avg.weights", # Pose condition encoder mappings "cond_encoder.weight": "condition_embedder.pose_embedder.weight", "cond_encoder.bias": "condition_embedder.pose_embedder.bias", diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 2db142522067..d4a51cdbc7b3 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -180,6 +180,25 @@ def forward(self, x): return self.conv(x) +class WanS2VCausalConvLayer(nn.Module): + """A layer that combines causal convolution, normalization, and activation in sequence.""" + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", eps=1e-6, **kwargs): + super().__init__() + + self.conv = WanS2VCausalConv1d(chan_in, chan_out, kernel_size, stride, dilation, pad_mode, **kwargs) + self.norm = nn.LayerNorm(chan_out, elementwise_affine=False, eps=eps) + self.act = nn.SiLU() + + def forward(self, x): + x = x.permute(0, 2, 1) + x = self.conv(x) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = self.act(x) + return x + + class WanS2VMotionEncoder(nn.Module): def __init__(self, in_dim: int, hidden_dim: int, num_attention_heads: int, need_global: bool = True): super().__init__() @@ -189,16 +208,14 @@ def __init__(self, in_dim: int, hidden_dim: int, num_attention_heads: int, need_ self.conv1_local = WanS2VCausalConv1d(in_dim, hidden_dim 
// 4 * num_attention_heads, 3, stride=1) if need_global: self.conv1_global = WanS2VCausalConv1d(in_dim, hidden_dim // 4, 3, stride=1) - self.act = nn.SiLU() - self.conv2 = WanS2VCausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2) - self.conv3 = WanS2VCausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2) + self.conv2 = WanS2VCausalConvLayer(hidden_dim // 4, hidden_dim // 2, 3, stride=2) + self.conv3 = WanS2VCausalConvLayer(hidden_dim // 2, hidden_dim, 3, stride=2) if need_global: self.final_linear = nn.Linear(hidden_dim, hidden_dim) self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6) - self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6) - self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6) + self.act = nn.SiLU() self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) @@ -210,16 +227,8 @@ def forward(self, x): x = x.unflatten(1, (self.num_attention_heads, -1)).permute(0, 1, 3, 2).flatten(0, 1) x = self.norm1(x) x = self.act(x) - x = x.permute(0, 2, 1) x = self.conv2(x) - x = x.permute(0, 2, 1) - x = self.norm2(x) - x = self.act(x) - x = x.permute(0, 2, 1) x = self.conv3(x) - x = x.permute(0, 2, 1) - x = self.norm3(x) - x = self.act(x) x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1) x = torch.cat([x, padding], dim=-2) @@ -232,16 +241,8 @@ def forward(self, x): x = x.permute(0, 2, 1) x = self.norm1(x) x = self.act(x) - x = x.permute(0, 2, 1) x = self.conv2(x) - x = x.permute(0, 2, 1) - x = self.norm2(x) - x = self.act(x) - x = x.permute(0, 2, 1) x = self.conv3(x) - x = x.permute(0, 2, 1) - x = self.norm3(x) - x = self.act(x) x = self.final_linear(x) x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) From 9dab88f2cffe89e07a6340115fb39a36f9359a13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 24 Sep 2025 18:20:54 +0300 Subject: [PATCH 115/131] style --- src/diffusers/models/transformers/transformer_wan_s2v.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index d4a51cdbc7b3..446f4bf04d83 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -183,7 +183,9 @@ def forward(self, x): class WanS2VCausalConvLayer(nn.Module): """A layer that combines causal convolution, normalization, and activation in sequence.""" - def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", eps=1e-6, **kwargs): + def __init__( + self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", eps=1e-6, **kwargs + ): super().__init__() self.conv = WanS2VCausalConv1d(chan_in, chan_out, kernel_size, stride, dilation, pad_mode, **kwargs) From 70ef9c3c1c62aa12055c167f32b44c2b47fe2dbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 24 Sep 2025 18:21:10 +0300 Subject: [PATCH 116/131] Add CP configs --- src/diffusers/models/transformers/transformer_wan_s2v.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 446f4bf04d83..d10b783bcea0 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -70,6 +70,7 @@ def 
_get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t class WanS2VAttnProcessor: _attention_backend = None + _parallel_config = None def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -140,7 +141,7 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): dropout_p=0.0, is_causal=False, backend=self._attention_backend, - attention_kwargs=kwargs, + parallel_config=self._parallel_config, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) @@ -153,7 +154,7 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): dropout_p=0.0, is_causal=False, backend=self._attention_backend, - attention_kwargs=kwargs, + parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) From 2d6176b1f83dabf9eaf59c9a36aaca60ab54aa96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 24 Sep 2025 18:26:13 +0300 Subject: [PATCH 117/131] Update attention dispather usage --- src/diffusers/models/attention_dispatch.py | 23 ++++--------------- .../transformers/transformer_wan_s2v.py | 18 ++++----------- 2 files changed, 9 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 48d03e308756..0a2ad681237b 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1246,24 +1246,11 @@ def _flash_varlen_attention( if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): - if max_seqlen_k is not None: - seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) - - (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, _) = _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, max_seqlen_k, attn_mask=attn_mask, device=query.device - ) - max_seqlen_k = seq_len_kv - else: - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) - ) - else: - seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) - cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) - cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) key_valid, value_valid = [], [] for b in range(batch_size): diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index d10b783bcea0..bf96f4d73687 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -85,7 +85,6 @@ def __call__( encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, ) -> torch.Tensor: encoder_hidden_states_img = None if attn.add_k_proj is not None: @@ -358,11 +357,7 @@ def forward( else: attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states) - attention_kwargs = { - 
"max_seqlen_k": torch.ones(attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device) - * attn_audio_emb.shape[1] - } - residual_out = self.injector[audio_attn_id](attn_hidden_states, attn_audio_emb, None, None, **attention_kwargs) + residual_out = self.injector[audio_attn_id](attn_hidden_states, attn_audio_emb, None, None) residual_out = residual_out.unflatten(0, (-1, merged_audio_emb_num_frames)).flatten(1, 2) hidden_states[:, :original_sequence_length] = hidden_states[:, :original_sequence_length] + residual_out @@ -745,7 +740,6 @@ def forward( encoder_hidden_states: torch.Tensor, temb: Tuple[torch.Tensor, torch.Tensor], rotary_emb: torch.Tensor, - attention_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: seg_idx = temb[1].item() seg_idx = min(max(0, seg_idx), hidden_states.shape[1]) @@ -773,7 +767,7 @@ def forward( norm_hidden_states = torch.cat(parts, dim=1).type_as(hidden_states) # 1. Self-attention - attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb, **(attention_kwargs or {})) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) z = [] for i in range(2): z.append(attn_output[:, seg_idx[i] : seg_idx[i + 1]] * gate_msa[:, i : i + 1]) @@ -782,7 +776,7 @@ def forward( # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) - attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None, **(attention_kwargs or {})) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) hidden_states = hidden_states + attn_output # 3. Feed-forward @@ -1102,7 +1096,6 @@ def forward( drop_motion_frames, add_last_motion, ) - attention_kwargs = {"max_seqlen_k": sequence_length.item()} hidden_states = hidden_states + self.trainable_condition_mask(mask_input).to(hidden_states.dtype) @@ -1132,7 +1125,6 @@ def forward( encoder_hidden_states, timestep_proj, rotary_emb, - attention_kwargs, ) if block_idx in self.audio_injector.injected_block_id.keys(): hidden_states = self.audio_injector( @@ -1145,9 +1137,7 @@ def forward( ) else: for block_idx, block in enumerate(self.blocks): - hidden_states = block( - hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_kwargs - ) + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) if block_idx in self.audio_injector.injected_block_id.keys(): hidden_states = self.audio_injector( block_idx, From 77da3e369317d76974483687feb74c500586ad9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 24 Sep 2025 18:35:43 +0300 Subject: [PATCH 118/131] Refactor example docstring for aspect ratio resizing and update num_frames_per_chunk --- src/diffusers/pipelines/wan/pipeline_wan_s2v.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index 315b10e77390..aa72f0fc24f3 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -134,13 +134,7 @@ ... return target_height, target_width - >>> def aspect_ratio_resize(image, pipe, max_area): - ... height, width = get_size_less_than_area(image.size[1], image.size[0], target_area=max_area) - ... image = image.resize((width, height)) - ... 
return image, height, width - - - >>> image, height, width = aspect_ratio_resize(first_frame, pipe, 480 * 832) + >>> height, width = get_size_less_than_area(image.height, image.width, target_area=480 * 832) >>> prompt = "Einstein singing a song." @@ -151,7 +145,7 @@ ... sampling_rate=sampling_rate, ... height=height, ... width=width, - ... num_frames_per_chunk=81, + ... num_frames_per_chunk=80, ... # pose_video_path_or_url=pose_video_path_or_url, ... ).frames[0] >>> export_to_video(output, "output.mp4", fps=16) From 9f61f5c0073cfa7ddfdb592433962c6c8f849e54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 24 Sep 2025 21:01:22 +0300 Subject: [PATCH 119/131] up docs --- docs/source/en/api/pipelines/wan.md | 67 ++--------------------------- 1 file changed, 3 insertions(+), 64 deletions(-) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index c206a6cb02c1..6c1920f4b139 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -253,7 +253,7 @@ The example below demonstrates how to use the speech-to-video pipeline to genera import numpy as np, math import torch from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline -from diffusers.utils import export_to_video, load_image, load_audio, load_video +from diffusers.utils import export_to_merged_video_audio, load_image, load_audio, load_video from transformers import Wav2Vec2ForCTC import requests from PIL import Image @@ -336,18 +336,13 @@ def get_size_less_than_area(height, return target_height, target_width -def aspect_ratio_resize(image, pipe, max_area): - height, width = get_size_less_than_area(image.size[1], image.size[0], target_area=max_area) - image = image.resize((width, height)) - return image, height, width - -image, height, width = aspect_ratio_resize(first_frame, pipe, 480*832) +height, width = get_size_less_than_area(first_frame.height, first_frame.width, 480*832) prompt = "Einstein singing a song." output = pipe( prompt=prompt, image=image, audio=audio, sampling_rate=sampling_rate, - height=height, width=width, num_frames_per_chunk=81, + height=height, width=width, num_frames_per_chunk=80, #pose_video_path_or_url=pose_video_path_or_url, ).frames[0] export_to_video(output, "output.mp4", fps=16) @@ -355,62 +350,6 @@ export_to_video(output, "output.mp4", fps=16) # Lastly, we need to merge the video and audio into a new video, with the duration set to # the shorter of the two and overwrite the original video file. 
-import os, logging, subprocess, shutil - -def merge_video_audio(video_path: str, audio_path: str): - logging.basicConfig(level=logging.INFO) - - if not os.path.exists(video_path): - raise FileNotFoundError(f"video file {video_path} does not exist") - if not os.path.exists(audio_path): - raise FileNotFoundError(f"audio file {audio_path} does not exist") - - base, ext = os.path.splitext(video_path) - temp_output = f"{base}_temp{ext}" - - try: - # Create ffmpeg command - command = [ - 'ffmpeg', - '-y', # overwrite - '-i', - video_path, - '-i', - audio_path, - '-c:v', - 'copy', # copy video stream - '-c:a', - 'aac', # use AAC audio encoder - '-b:a', - '192k', # set audio bitrate (optional) - '-map', - '0:v:0', # select the first video stream - '-map', - '1:a:0', # select the first audio stream - '-shortest', # choose the shortest duration - temp_output - ] - - # Execute the command - logging.info("Start merging video and audio...") - result = subprocess.run( - command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - - # Check result - if result.returncode != 0: - error_msg = f"FFmpeg execute failed: {result.stderr}" - logging.error(error_msg) - raise RuntimeError(error_msg) - - shutil.move(temp_output, video_path) - logging.info(f"Merge completed, saved to {video_path}") - - except Exception as e: - if os.path.exists(temp_output): - os.remove(temp_output) - logging.error(f"merge_video_audio failed with error: {e}") - -merge_video_audio("output.mp4", "audio.mp3") ``` From 079dd7d73004ac23cf708a3059132f0ed32015a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 25 Sep 2025 16:33:30 +0300 Subject: [PATCH 120/131] up docs --- docs/source/en/api/pipelines/wan.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 6c1920f4b139..0d6513d88648 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -253,7 +253,7 @@ The example below demonstrates how to use the speech-to-video pipeline to genera import numpy as np, math import torch from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline -from diffusers.utils import export_to_merged_video_audio, load_image, load_audio, load_video +from diffusers.utils import export_to_merged_video_audio, load_image, load_audio, load_video, export_to_video from transformers import Wav2Vec2ForCTC import requests from PIL import Image @@ -349,7 +349,7 @@ export_to_video(output, "output.mp4", fps=16) # Lastly, we need to merge the video and audio into a new video, with the duration set to # the shorter of the two and overwrite the original video file. 
- +export_to_merged_video_audio("output.mp4", "audio.mp3") ``` From 1c553a1552b9b64784ddcac1fffbfb0d7e007f04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 25 Sep 2025 18:59:38 +0300 Subject: [PATCH 121/131] up test --- .../pipelines/wan/test_wan_speech_to_video.py | 91 ++++++++++--------- 1 file changed, 48 insertions(+), 43 deletions(-) diff --git a/tests/pipelines/wan/test_wan_speech_to_video.py b/tests/pipelines/wan/test_wan_speech_to_video.py index 3bbe673d2de4..5230f31088eb 100644 --- a/tests/pipelines/wan/test_wan_speech_to_video.py +++ b/tests/pipelines/wan/test_wan_speech_to_video.py @@ -17,7 +17,7 @@ import numpy as np import torch from PIL import Image -from transformers import AutoTokenizer, T5EncoderModel +from transformers import AutoTokenizer, T5EncoderModel, Wav2Vec2ForCTC, Wav2Vec2Processor from diffusers import ( AutoencoderKLWan, @@ -76,7 +76,7 @@ def get_dummy_components(self): in_channels=16, out_channels=16, text_dim=32, - freq_dim=256, + freq_dim=16, ffn_dim=32, num_layers=3, cross_attn_norm=True, @@ -84,12 +84,18 @@ def get_dummy_components(self): rope_max_seq_len=32, ) + torch.manual_seed(0) + audio_encoder = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2") + audio_processor = Wav2Vec2Processor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2") + components = { "transformer": transformer, "vae": vae, "scheduler": scheduler, "text_encoder": text_encoder, "tokenizer": tokenizer, + "audio_encoder": audio_encoder, + "audio_processor": audio_processor, } return components @@ -99,26 +105,30 @@ def get_dummy_inputs(self, device, seed=0): else: generator = torch.Generator(device=device).manual_seed(seed) - num_frames = 17 height = 16 width = 16 - video = [Image.new("RGB", (height, width))] * num_frames - mask = [Image.new("L", (height, width), 0)] * num_frames + image = Image.new("RGB", (width, height)) + + sampling_rate = 16000 + audio_length = 0.5 + audio = np.random.rand(int(sampling_rate * audio_length)).astype(np.float32) inputs = { - "video": video, - "mask": mask, - "prompt": "dance monkey", - "negative_prompt": "negative", + "image": image, + "audio": audio, + "sampling_rate": sampling_rate, + "prompt": "A person speaking", + "negative_prompt": "low quality", "generator": generator, "num_inference_steps": 2, - "guidance_scale": 6.0, - "height": 16, - "width": 16, - "num_frames": num_frames, + "guidance_scale": 4.5, + "height": height, + "width": width, + "num_frames_per_chunk": 5, "max_sequence_length": 16, "output_type": "pt", + "pose_video_path_or_url": None, } return inputs @@ -132,18 +142,14 @@ def test_inference(self): inputs = self.get_dummy_inputs(device) video = pipe(**inputs).frames[0] - self.assertEqual(video.shape, (17, 3, 16, 16)) - - # fmt: off - expected_slice = [0.4523, 0.45198, 0.44872, 0.45326, 0.45211, 0.45258, 0.45344, 0.453, 0.52431, 0.52572, 0.50701, 0.5118, 0.53717, 0.53093, 0.50557, 0.51402] - # fmt: on + self.assertEqual(video.shape, (5, 3, 16, 16)) video_slice = video.flatten() - video_slice = torch.cat([video_slice[:8], video_slice[-8:]]) - video_slice = [round(x, 5) for x in video_slice.tolist()] - self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3)) + self.assertEqual(len(video_slice), 5 * 3 * 16 * 16) + self.assertTrue(torch.is_tensor(video)) + self.assertTrue(video.dtype == torch.float32) - def test_inference_with_single_reference_image(self): + def test_inference_with_audio_embeds(self): device = "cpu" components = self.get_dummy_components() @@ -152,20 
+158,21 @@ def test_inference_with_single_reference_image(self): pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) - inputs["reference_images"] = Image.new("RGB", (16, 16)) - video = pipe(**inputs).frames[0] - self.assertEqual(video.shape, (17, 3, 16, 16)) - # fmt: off - expected_slice = [0.45247, 0.45214, 0.44874, 0.45314, 0.45171, 0.45299, 0.45428, 0.45317, 0.51378, 0.52658, 0.53361, 0.52303, 0.46204, 0.50435, 0.52555, 0.51342] - # fmt: on + batch_size = 1 + num_layers = 5 + freq_dim = 16 + num_frames = 170 + audio_embeds = torch.randn(batch_size, num_layers, freq_dim, num_frames) - video_slice = video.flatten() - video_slice = torch.cat([video_slice[:8], video_slice[-8:]]) - video_slice = [round(x, 5) for x in video_slice.tolist()] - self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3)) + inputs["audio"] = None + inputs["sampling_rate"] = None + inputs["audio_embeds"] = audio_embeds - def test_inference_with_multiple_reference_image(self): + video = pipe(**inputs).frames[0] + self.assertEqual(video.shape, (17, 3, 16, 16)) + + def test_inference_with_different_sampling_rates(self): device = "cpu" components = self.get_dummy_components() @@ -174,18 +181,16 @@ def test_inference_with_multiple_reference_image(self): pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) - inputs["reference_images"] = [[Image.new("RGB", (16, 16))] * 2] - video = pipe(**inputs).frames[0] - self.assertEqual(video.shape, (17, 3, 16, 16)) - # fmt: off - expected_slice = [0.45321, 0.45221, 0.44818, 0.45375, 0.45268, 0.4519, 0.45271, 0.45253, 0.51244, 0.52223, 0.51253, 0.51321, 0.50743, 0.51177, 0.51626, 0.50983] - # fmt: on + sampling_rate = 22050 + audio_length = 1.0 + audio = np.random.rand(int(sampling_rate * audio_length)).astype(np.float32) - video_slice = video.flatten() - video_slice = torch.cat([video_slice[:8], video_slice[-8:]]) - video_slice = [round(x, 5) for x in video_slice.tolist()] - self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3)) + inputs["audio"] = audio + inputs["sampling_rate"] = sampling_rate + + video = pipe(**inputs).frames[0] + self.assertEqual(video.shape, (17, 3, 16, 16)) @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): From dfb99d05348a7820802502d6a6adb6ccdb9c78b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 26 Sep 2025 14:23:46 +0300 Subject: [PATCH 122/131] up tests --- .../transformers/transformer_wan_s2v.py | 9 ++-- .../pipelines/wan/test_wan_speech_to_video.py | 52 +++++++++---------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index bf96f4d73687..4a7fc8abc273 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -267,9 +267,9 @@ def forward(self, features): class CausalAudioEncoder(nn.Module): - def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_audio_token=4, need_global=False): + def __init__(self, dim=5120, num_weighted_avg_layers=25, out_dim=2048, num_audio_token=4, need_global=False): super().__init__() - self.weighted_avg = WeightedAveragelayer(num_layers) + self.weighted_avg = WeightedAveragelayer(num_weighted_avg_layers) self.encoder = WanS2VMotionEncoder( in_dim=dim, hidden_dim=out_dim, num_attention_heads=num_audio_token, need_global=need_global ) @@ -530,6 +530,7 @@ def __init__( 
pose_embed_dim: int, patch_size: Tuple[int], enable_adain: bool, + num_weighted_avg_layers: int, ): super().__init__() @@ -539,7 +540,7 @@ def __init__( self.time_proj = nn.Linear(dim, time_proj_dim) self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") self.causal_audio_encoder = CausalAudioEncoder( - dim=audio_embed_dim, out_dim=dim, num_audio_token=4, need_global=enable_adain + dim=audio_embed_dim, num_weighted_avg_layers=num_weighted_avg_layers, out_dim=dim, num_audio_token=4, need_global=enable_adain ) self.pose_embedder = nn.Conv3d(pose_embed_dim, dim, kernel_size=patch_size, stride=patch_size) @@ -863,6 +864,7 @@ def __init__( pose_dim: int = 16, ffn_dim: int = 13824, num_layers: int = 40, + num_weighted_avg_layers: int = 25, cross_attn_norm: bool = True, qk_norm: Optional[str] = "rms_norm_across_heads", eps: float = 1e-6, @@ -911,6 +913,7 @@ def __init__( pose_embed_dim=pose_dim, patch_size=patch_size, enable_adain=enable_adain, + num_weighted_avg_layers=num_weighted_avg_layers, ) # 3. Transformer blocks diff --git a/tests/pipelines/wan/test_wan_speech_to_video.py b/tests/pipelines/wan/test_wan_speech_to_video.py index 5230f31088eb..30b9b3717606 100644 --- a/tests/pipelines/wan/test_wan_speech_to_video.py +++ b/tests/pipelines/wan/test_wan_speech_to_video.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import unittest import numpy as np @@ -26,7 +27,7 @@ WanSpeechToVideoPipeline, ) -from ...testing_utils import enable_full_determinism +from ...testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -64,7 +65,7 @@ def get_dummy_components(self): ) torch.manual_seed(0) - scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0) text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") @@ -76,12 +77,17 @@ def get_dummy_components(self): in_channels=16, out_channels=16, text_dim=32, - freq_dim=16, + freq_dim=256, ffn_dim=32, num_layers=3, + num_weighted_avg_layers=5, cross_attn_norm=True, qk_norm="rms_norm_across_heads", rope_max_seq_len=32, + audio_dim=16, + audio_inject_layers=[0, 2], + enable_adain=True, + enable_framepack=True, ) torch.manual_seed(0) @@ -104,9 +110,10 @@ def get_dummy_inputs(self, device, seed=0): generator = torch.manual_seed(seed) else: generator = torch.Generator(device=device).manual_seed(seed) - - height = 16 - width = 16 + # Use 64x64 so that after VAE downsampling (factor ~8) latent spatial size is 8x8, which matches + # the frame-packing conv kernel requirement. The largest kernel is (4, 8, 8) so we need at least 8x8 latents. 
+ height = 64 + width = 64 image = Image.new("RGB", (width, height)) @@ -125,10 +132,12 @@ def get_dummy_inputs(self, device, seed=0): "guidance_scale": 4.5, "height": height, "width": width, - "num_frames_per_chunk": 5, + "num_frames_per_chunk": 4, + "num_chunks": 2, "max_sequence_length": 16, "output_type": "pt", "pose_video_path_or_url": None, + "init_first_frame": True, } return inputs @@ -142,14 +151,12 @@ def test_inference(self): inputs = self.get_dummy_inputs(device) video = pipe(**inputs).frames[0] - self.assertEqual(video.shape, (5, 3, 16, 16)) - - video_slice = video.flatten() - self.assertEqual(len(video_slice), 5 * 3 * 16 * 16) - self.assertTrue(torch.is_tensor(video)) - self.assertTrue(video.dtype == torch.float32) + expected_num_frames = inputs["num_frames_per_chunk"] * inputs["num_chunks"] + if not inputs["init_first_frame"]: + expected_num_frames -= 3 + self.assertEqual(video.shape, (expected_num_frames, 3, inputs["height"], inputs["width"])) - def test_inference_with_audio_embeds(self): + def test_inference_with_pose(self): device = "cpu" components = self.get_dummy_components() @@ -158,19 +165,12 @@ def test_inference_with_audio_embeds(self): pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) - - batch_size = 1 - num_layers = 5 - freq_dim = 16 - num_frames = 170 - audio_embeds = torch.randn(batch_size, num_layers, freq_dim, num_frames) - - inputs["audio"] = None - inputs["sampling_rate"] = None - inputs["audio_embeds"] = audio_embeds - + inputs["pose_video_path_or_url"] = "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4" video = pipe(**inputs).frames[0] - self.assertEqual(video.shape, (17, 3, 16, 16)) + expected_num_frames = inputs["num_frames_per_chunk"] * inputs["num_chunks"] + if not inputs["init_first_frame"]: + expected_num_frames -= 3 + self.assertEqual(video.shape, (expected_num_frames, 3, inputs["height"], inputs["width"])) def test_inference_with_different_sampling_rates(self): device = "cpu" From b5421f3beb267bbc7cbc8e4eeeb329e4401b5c26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 26 Sep 2025 15:29:39 +0300 Subject: [PATCH 123/131] down --- .../pipelines/wan/test_wan_speech_to_video.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tests/pipelines/wan/test_wan_speech_to_video.py b/tests/pipelines/wan/test_wan_speech_to_video.py index 30b9b3717606..f98a5a86ecbf 100644 --- a/tests/pipelines/wan/test_wan_speech_to_video.py +++ b/tests/pipelines/wan/test_wan_speech_to_video.py @@ -172,25 +172,6 @@ def test_inference_with_pose(self): expected_num_frames -= 3 self.assertEqual(video.shape, (expected_num_frames, 3, inputs["height"], inputs["width"])) - def test_inference_with_different_sampling_rates(self): - device = "cpu" - - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - - sampling_rate = 22050 - audio_length = 1.0 - audio = np.random.rand(int(sampling_rate * audio_length)).astype(np.float32) - - inputs["audio"] = audio - inputs["sampling_rate"] = sampling_rate - - video = pipe(**inputs).frames[0] - self.assertEqual(video.shape, (17, 3, 16, 16)) @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): From 0cbe32e5a7eb41891212ef4dca0e49c34919aef4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 26 Sep 2025 15:57:23 +0300 Subject: [PATCH 124/131] Add deterministic 
audio generation and callback configuration test --- .../pipelines/wan/test_wan_speech_to_video.py | 45 ++++++++++++++++++- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/wan/test_wan_speech_to_video.py b/tests/pipelines/wan/test_wan_speech_to_video.py index f98a5a86ecbf..2ad71cbad05f 100644 --- a/tests/pipelines/wan/test_wan_speech_to_video.py +++ b/tests/pipelines/wan/test_wan_speech_to_video.py @@ -29,7 +29,7 @@ from ...testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineTesterMixin, to_np enable_full_determinism() @@ -119,7 +119,9 @@ def get_dummy_inputs(self, device, seed=0): sampling_rate = 16000 audio_length = 0.5 - audio = np.random.rand(int(sampling_rate * audio_length)).astype(np.float32) + # Make audio generation deterministic by using a fixed seed + np_rng = np.random.RandomState(seed) + audio = np_rng.rand(int(sampling_rate * audio_length)).astype(np.float32) inputs = { "image": image, @@ -200,3 +202,42 @@ def test_float16_inference(self): ) def test_save_load_float16(self): pass + + def test_callback_cfg(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + if "guidance_scale" not in sig.parameters: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_increase_guidance(pipe, i, t, callback_kwargs): + pipe._guidance_scale += 1.0 + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # use cfg guidance because some pipelines modify the shape of the latents + # outside of the denoising loop + inputs["guidance_scale"] = 2.0 + inputs["callback_on_step_end"] = callback_increase_guidance + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + _ = pipe(**inputs)[0] + + # For this pipeline, the total number of timesteps is multiplied by num_chunks + # since each chunk runs independently with its own denoising loop + expected_final_guidance = inputs["guidance_scale"] + (pipe.num_timesteps * inputs["num_chunks"]) + assert pipe.guidance_scale == expected_final_guidance From 685d86e654b71cfc09a7ed2aa7433f15271cb5f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 26 Sep 2025 16:01:09 +0300 Subject: [PATCH 125/131] up --- tests/pipelines/wan/test_wan_speech_to_video.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/wan/test_wan_speech_to_video.py b/tests/pipelines/wan/test_wan_speech_to_video.py index 2ad71cbad05f..f2f721a2d916 100644 --- a/tests/pipelines/wan/test_wan_speech_to_video.py +++ b/tests/pipelines/wan/test_wan_speech_to_video.py @@ -237,7 +237,9 @@ def callback_increase_guidance(pipe, i, t, callback_kwargs): inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs _ = pipe(**inputs)[0] + # we increase the guidance scale by 1.0 at 
every step + # check that the guidance scale is increased by the number of scheduler timesteps + # accounts for models that modify the number of inference steps based on strength. # For this pipeline, the total number of timesteps is multiplied by num_chunks # since each chunk runs independently with its own denoising loop - expected_final_guidance = inputs["guidance_scale"] + (pipe.num_timesteps * inputs["num_chunks"]) - assert pipe.guidance_scale == expected_final_guidance + assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps * inputs["num_chunks"]) From 8a5bb49d06ef074319c49bb77f1999a0eae21948 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 26 Sep 2025 16:01:37 +0300 Subject: [PATCH 126/131] style --- src/diffusers/models/transformers/transformer_wan_s2v.py | 6 +++++- tests/pipelines/wan/test_wan_speech_to_video.py | 3 +-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 4a7fc8abc273..309aaf6e66f2 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -540,7 +540,11 @@ def __init__( self.time_proj = nn.Linear(dim, time_proj_dim) self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") self.causal_audio_encoder = CausalAudioEncoder( - dim=audio_embed_dim, num_weighted_avg_layers=num_weighted_avg_layers, out_dim=dim, num_audio_token=4, need_global=enable_adain + dim=audio_embed_dim, + num_weighted_avg_layers=num_weighted_avg_layers, + out_dim=dim, + num_audio_token=4, + need_global=enable_adain, ) self.pose_embedder = nn.Conv3d(pose_embed_dim, dim, kernel_size=patch_size, stride=patch_size) diff --git a/tests/pipelines/wan/test_wan_speech_to_video.py b/tests/pipelines/wan/test_wan_speech_to_video.py index f2f721a2d916..6371ed040e45 100644 --- a/tests/pipelines/wan/test_wan_speech_to_video.py +++ b/tests/pipelines/wan/test_wan_speech_to_video.py @@ -29,7 +29,7 @@ from ...testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import PipelineTesterMixin enable_full_determinism() @@ -174,7 +174,6 @@ def test_inference_with_pose(self): expected_num_frames -= 3 self.assertEqual(video.shape, (expected_num_frames, 3, inputs["height"], inputs["width"])) - @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): pass From 97c2125bcfce55c547e6c3dab782e59aa32b22be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 26 Sep 2025 16:02:47 +0300 Subject: [PATCH 127/131] up --- .../pipelines/wan/test_wan_speech_to_video.py | 56 +++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/pipelines/wan/test_wan_speech_to_video.py b/tests/pipelines/wan/test_wan_speech_to_video.py index 6371ed040e45..7396a151b3be 100644 --- a/tests/pipelines/wan/test_wan_speech_to_video.py +++ b/tests/pipelines/wan/test_wan_speech_to_video.py @@ -174,34 +174,6 @@ def test_inference_with_pose(self): expected_num_frames -= 3 self.assertEqual(video.shape, (expected_num_frames, 3, inputs["height"], inputs["width"])) - @unittest.skip("Test not supported") - def test_attention_slicing_forward_pass(self): - pass - - @unittest.skip("Errors out because passing 
multiple prompts at once is not yet supported by this pipeline.") - def test_encode_prompt_works_in_isolation(self): - pass - - @unittest.skip("Batching is not yet supported with this pipeline") - def test_inference_batch_consistent(self): - pass - - @unittest.skip("Batching is not yet supported with this pipeline") - def test_inference_batch_single_identical(self): - return super().test_inference_batch_single_identical() - - @unittest.skip( - "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs" - ) - def test_float16_inference(self): - pass - - @unittest.skip( - "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs" - ) - def test_save_load_float16(self): - pass - def test_callback_cfg(self): sig = inspect.signature(self.pipeline_class.__call__) has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters @@ -242,3 +214,31 @@ def callback_increase_guidance(pipe, i, t, callback_kwargs): # For this pipeline, the total number of timesteps is multiplied by num_chunks # since each chunk runs independently with its own denoising loop assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps * inputs["num_chunks"]) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip("Errors out because passing multiple prompts at once is not yet supported by this pipeline.") + def test_encode_prompt_works_in_isolation(self): + pass + + @unittest.skip("Batching is not yet supported with this pipeline") + def test_inference_batch_consistent(self): + pass + + @unittest.skip("Batching is not yet supported with this pipeline") + def test_inference_batch_single_identical(self): + return super().test_inference_batch_single_identical() + + @unittest.skip( + "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs" + ) + def test_float16_inference(self): + pass + + @unittest.skip( + "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs" + ) + def test_save_load_float16(self): + pass From 2575d473aa8dd11b43b422e56acc9f4a52b7aca2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 29 Sep 2025 08:24:05 +0300 Subject: [PATCH 128/131] Use immutable default values --- src/diffusers/models/transformers/transformer_wan_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py index 309aaf6e66f2..4d810eeca630 100644 --- a/src/diffusers/models/transformers/transformer_wan_s2v.py +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -862,7 +862,7 @@ def __init__( text_dim: int = 4096, freq_dim: int = 256, audio_dim: int = 1024, - audio_inject_layers: List[int] = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39], + audio_inject_layers: Tuple[int] = (0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39), enable_adain: bool = True, adain_mode: str = "attn_norm", pose_dim: int = 16, From cb615bb073abce1307416112e0310deebc6eb146 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 17 Oct 2025 08:59:02 +0300 Subject: [PATCH 129/131] Refactor device handling in WanSpeechToVideoPipeline for consistency - Updated device references in audio encoding and pose video loading to use a unified device variable. 
- Enhanced image preprocessing to include a resize mode option for better handling of input dimensions. Co-authored-by: Ju Hoon Park --- .../pipelines/wan/pipeline_wan_s2v.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py index aa72f0fc24f3..6f78cec07442 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -343,7 +343,7 @@ def encode_audio( input_values = self.audio_processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_values # retrieve logits & take argmax - res = self.audio_encoder(input_values.to(self.audio_encoder.device), output_hidden_states=True) + res = self.audio_encoder(input_values.to(device), output_hidden_states=True) feat = torch.cat(res.hidden_states) feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate) @@ -635,24 +635,22 @@ def prepare_latents( def load_pose_condition( self, pose_video, num_chunks, num_frames_per_chunk, height, width, latents_mean, latents_std ): + device = self._execution_device + dtype = self.vae.dtype if pose_video is not None: padding_frame_num = num_chunks * num_frames_per_chunk - pose_video.shape[2] - pose_video = pose_video.to(dtype=self.vae.dtype, device=self.vae.device) + pose_video = pose_video.to(dtype=dtype, device=device) pose_video = torch.cat( [ pose_video, - -torch.ones( - [1, 3, padding_frame_num, height, width], dtype=self.vae.dtype, device=self.vae.device - ), + -torch.ones([1, 3, padding_frame_num, height, width], dtype=dtype, device=device), ], dim=2, ) pose_video = torch.chunk(pose_video, num_chunks, dim=2) else: - pose_video = [ - -torch.ones([1, 3, num_frames_per_chunk, height, width], dtype=self.vae.dtype, device=self.vae.device) - ] + pose_video = [-torch.ones([1, 3, num_frames_per_chunk, height, width], dtype=dtype, device=device)] # Vectorized processing: concatenate all chunks along batch dimension all_poses = torch.cat( @@ -886,7 +884,9 @@ def __call__( # 5. Prepare latent variables num_channels_latents = self.vae.config.z_dim - image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + image = self.video_processor.preprocess( + image, height=height, width=width, resize_mode="resize_min_center_crop" + ).to(device, dtype=torch.float32) pose_video = None if pose_video_path_or_url is not None: From 54cfc71e80b9638ee237ef8800fdf06bce7c8877 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Tue, 21 Oct 2025 08:12:55 +0300 Subject: [PATCH 130/131] Update Wan-S2V model description and contributor info Added contributor information and enhanced model description. --- docs/source/en/api/pipelines/wan.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 0d6513d88648..a8db62e6c72c 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -244,6 +244,8 @@ export_to_video(output, "output.mp4", fps=16) *Current state-of-the-art (SOTA) methods for audio-driven character animation demonstrate promising performance for scenarios primarily involving speech and singing. 
However, they often fall short in more complex film and television productions, which demand sophisticated elements such as nuanced character interactions, realistic body movements, and dynamic camera work. To address this long-standing challenge of achieving film-level character animation, we propose an audio-driven model, which we refer to as Wan-S2V, built upon Wan. Our model achieves significantly enhanced expressiveness and fidelity in cinematic contexts compared to existing approaches. We conducted extensive experiments, benchmarking our method against cutting-edge models such as Hunyuan-Avatar and Omnihuman. The experimental results consistently demonstrate that our approach significantly outperforms these existing solutions. Additionally, we explore the versatility of our method through its applications in long-form video generation and precise video lip-sync editing.* +This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz). + The example below demonstrates how to use the speech-to-video pipeline to generate a video using a text description, a starting frame, an audio, and a pose video. @@ -487,4 +489,4 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip ## WanPipelineOutput -[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput \ No newline at end of file +[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput From 5062ea7519ebe488f53ff663c7e50bf3506ea7af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Tue, 21 Oct 2025 08:14:02 +0300 Subject: [PATCH 131/131] Update Wan-S2V documentation with project page link Added project page link for Wan-S2V model and improved context. --- docs/source/en/api/pipelines/wan.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index a8db62e6c72c..116d957019fd 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -244,6 +244,8 @@ export_to_video(output, "output.mp4", fps=16) *Current state-of-the-art (SOTA) methods for audio-driven character animation demonstrate promising performance for scenarios primarily involving speech and singing. However, they often fall short in more complex film and television productions, which demand sophisticated elements such as nuanced character interactions, realistic body movements, and dynamic camera work. To address this long-standing challenge of achieving film-level character animation, we propose an audio-driven model, which we refer to as Wan-S2V, built upon Wan. Our model achieves significantly enhanced expressiveness and fidelity in cinematic contexts compared to existing approaches. We conducted extensive experiments, benchmarking our method against cutting-edge models such as Hunyuan-Avatar and Omnihuman. The experimental results consistently demonstrate that our approach significantly outperforms these existing solutions. Additionally, we explore the versatility of our method through its applications in long-form video generation and precise video lip-sync editing.* +Project page: https://humanaigc.github.io/wan-s2v-webpage/ + This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz). The example below demonstrates how to use the speech-to-video pipeline to generate a video using a text description, a starting frame, an audio, and a pose video.
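Note on the sizing helper referenced throughout the examples in these patches: the body of `get_size_less_than_area` is elided in the hunks shown here, so the snippet below is only a minimal, hypothetical sketch of what such a helper might look like — it is not the implementation from `docs/source/en/api/pipelines/wan.md`. It assumes the goal is simply to keep `height * width` at or under `target_area` while preserving aspect ratio, with both sides snapped down to a spatial multiple (`divisor=16` is an assumption chosen so the downsampled latent grid stays aligned).

```python
import math


def get_size_less_than_area(height: int, width: int, target_area: int = 480 * 832, divisor: int = 16):
    # Shrink (never enlarge) so that height * width fits within target_area.
    scale = min(1.0, math.sqrt(target_area / (height * width)))
    # Snap both sides down to a multiple of `divisor` so the latent grid stays valid.
    target_height = max(divisor, int(height * scale) // divisor * divisor)
    target_width = max(divisor, int(width * scale) // divisor * divisor)
    return target_height, target_width


# Example: fit a 1080x1920 frame under the 480*832 budget used in the Wan-S2V examples above.
print(get_size_less_than_area(1080, 1920, target_area=480 * 832))  # e.g. (464, 832)
```

The real helper may round or bound differently; any such difference only changes the exact resolution passed as `height`/`width`, not how the pipeline is invoked.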