
Commit d9e11dd

Merge pull request #1 from tolgacangoz/integrations/wan2.2-s2v
Integrations/wan2.2 s2v
2 parents 022b4fd + bb55dcc commit d9e11dd

File tree

19 files changed: +3123 additions, -56 deletions

docs/source/en/api/pipelines/wan.md

Lines changed: 147 additions & 17 deletions
@@ -40,6 +40,7 @@ The following Wan models are supported in Diffusers:
 - [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
 - [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)
 - [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)
+- [Wan 2.2 S2V 14B](https://huggingface.co/Wan-AI/Wan2.2-S2V-14B-Diffusers)
 
 > [!TIP]
 > Click on the Wan models in the right sidebar for more examples of video generation.
@@ -95,15 +96,15 @@ pipeline = WanPipeline.from_pretrained(
 pipeline.to("cuda")
 
 prompt = """
-The camera rushes from far to near in a low-angle shot,
-revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
-for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
-Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
+The camera rushes from far to near in a low-angle shot,
+revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
+for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
+Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
 shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
 """
 negative_prompt = """
-Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
-low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
+Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
+low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
 misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
 """
 
@@ -150,15 +151,15 @@ pipeline.transformer = torch.compile(
 )
 
 prompt = """
-The camera rushes from far to near in a low-angle shot,
-revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
-for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
-Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
+The camera rushes from far to near in a low-angle shot,
+revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
+for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
+Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
 shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
 """
 negative_prompt = """
-Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
-low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
+Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
+low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
 misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
 """
 
@@ -236,6 +237,129 @@ export_to_video(output, "output.mp4", fps=16)
 </hfoption>
 </hfoptions>
 
+
+### Wan-S2V: Audio-Driven Cinematic Video Generation
+
+[Wan-S2V](https://huggingface.co/papers/2508.18621) by the Wan Team.
+
+*Current state-of-the-art (SOTA) methods for audio-driven character animation demonstrate promising performance for scenarios primarily involving speech and singing. However, they often fall short in more complex film and television productions, which demand sophisticated elements such as nuanced character interactions, realistic body movements, and dynamic camera work. To address this long-standing challenge of achieving film-level character animation, we propose an audio-driven model, which we refer to as Wan-S2V, built upon Wan. Our model achieves significantly enhanced expressiveness and fidelity in cinematic contexts compared to existing approaches. We conducted extensive experiments, benchmarking our method against cutting-edge models such as Hunyuan-Avatar and Omnihuman. The experimental results consistently demonstrate that our approach significantly outperforms these existing solutions. Additionally, we explore the versatility of our method through its applications in long-form video generation and precise video lip-sync editing.*
+
+Project page: https://humanaigc.github.io/wan-s2v-webpage/
+
+This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
+
+The example below demonstrates how to use the speech-to-video pipeline to generate a video from a text description, a starting frame, an audio clip, and, optionally, a pose video.
+
+<hfoptions id="S2V usage">
+<hfoption id="usage">
+
+```python
+import math
+from io import BytesIO
+
+import requests
+import torch
+from PIL import Image
+from transformers import Wav2Vec2ForCTC
+
+from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline
+from diffusers.utils import export_to_merged_video_audio, export_to_video, load_audio
+
+model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers"
+audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanSpeechToVideoPipeline.from_pretrained(
+    model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+headers = {"User-Agent": "Mozilla/5.0"}
+url = "https://upload.wikimedia.org/wikipedia/commons/4/46/Albert_Einstein_sticks_his_tongue.jpg"
+resp = requests.get(url, headers=headers, timeout=30)
+image = Image.open(BytesIO(resp.content))
+
+audio, sampling_rate = load_audio("https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/Five%20Hundred%20Miles.MP3")
+# pose_video_path_or_url = "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4"
+
+
+def get_size_less_than_area(height, width, target_area=1024 * 704, divisor=64):
+    if height * width <= target_area:
+        # If the original image area is already less than or equal to the target,
+        # no resizing is needed, just padding. Still need to ensure that the padded area doesn't exceed the target.
+        max_upper_area = target_area
+        min_scale = 0.1
+        max_scale = 1.0
+    else:
+        # Resize to fit within the target area and then pad to multiples of `divisor`
+        max_upper_area = target_area  # Maximum allowed total pixel count after padding
+        d = divisor - 1
+        b = d * (height + width)
+        a = height * width
+        c = d**2 - max_upper_area
+
+        # Calculate scale boundaries using the quadratic equation
+        min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (2 * a)  # Scale when maximum padding is applied
+        max_scale = math.sqrt(max_upper_area / (height * width))  # Scale without any padding
+
+    # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area.
+    # Use a binary-search-like iteration to find this scale.
+    find_it = False
+    for i in range(100):
+        scale = max_scale - (max_scale - min_scale) * i / 100
+        new_height, new_width = int(height * scale), int(width * scale)
+
+        # Pad to make dimensions divisible by 64
+        pad_height = (64 - new_height % 64) % 64
+        pad_width = (64 - new_width % 64) % 64
+        pad_top = pad_height // 2
+        pad_bottom = pad_height - pad_top
+        pad_left = pad_width // 2
+        pad_right = pad_width - pad_left
+
+        padded_height, padded_width = new_height + pad_height, new_width + pad_width
+
+        if padded_height * padded_width <= max_upper_area:
+            find_it = True
+            break
+
+    if find_it:
+        return padded_height, padded_width
+    else:
+        # Fallback: calculate target dimensions based on aspect ratio and divisor alignment
+        aspect_ratio = width / height
+        target_width = int((target_area * aspect_ratio) ** 0.5 // divisor * divisor)
+        target_height = int((target_area / aspect_ratio) ** 0.5 // divisor * divisor)
+
+        # Ensure the result is not larger than the original resolution
+        if target_width >= width or target_height >= height:
+            target_width = int(width // divisor * divisor)
+            target_height = int(height // divisor * divisor)
+
+        return target_height, target_width
+
+
+height, width = get_size_less_than_area(image.height, image.width, 480 * 832)
+
+prompt = "Einstein singing a song."
+
+output = pipe(
+    prompt=prompt, image=image, audio=audio, sampling_rate=sampling_rate,
+    height=height, width=width, num_frames_per_chunk=80,
+    # pose_video_path_or_url=pose_video_path_or_url,
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
+
+# Lastly, merge the video and audio into a new file, with the duration set to
+# the shorter of the two, overwriting the original video file.
+export_to_merged_video_audio("output.mp4", "audio.mp3")
+```
+
+</hfoption>
+</hfoptions>
+
+
 ### Any-to-Video Controllable Generation
 
 Wan VACE supports various generation techniques which achieve controllable video generation. Some of the capabilities include:
@@ -281,10 +405,10 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
 
 # use "steamboat willie style" to trigger the LoRA
 prompt = """
-steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
-revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
-for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
-Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
+steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
+revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
+for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
+Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
 shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
 """
 
@@ -353,6 +477,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
   - all
   - __call__
 
+## WanSpeechToVideoPipeline
+
+[[autodoc]] WanSpeechToVideoPipeline
+  - all
+  - __call__
+
 ## WanVideoToVideoPipeline
 
 [[autodoc]] WanVideoToVideoPipeline
@@ -361,4 +491,4 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
 
 ## WanPipelineOutput
 
-[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
+[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
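The S2V example added above leaves the pose-driven path commented out. Below is a minimal sketch of enabling it, together with standard diffusers CPU offloading to lower peak VRAM; it assumes the pipeline accepts `pose_video_path_or_url` exactly as the commented-out call suggests, and it reuses `pipe`, `prompt`, `image`, `audio`, `sampling_rate`, `height`, and `width` from that example.

```python
# Sketch only: reuses the objects defined in the S2V example above.
# Assumption: `pose_video_path_or_url` is accepted as shown in the commented-out call.
pipe.enable_model_cpu_offload()  # optional, trades speed for lower peak VRAM

pose_video_path_or_url = "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4"
output = pipe(
    prompt=prompt, image=image, audio=audio, sampling_rate=sampling_rate,
    height=height, width=width, num_frames_per_chunk=80,
    pose_video_path_or_url=pose_video_path_or_url,
).frames[0]
export_to_video(output, "pose_output.mp4", fps=16)
```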

scripts/convert_wan_to_diffusers.py

Lines changed: 112 additions & 4 deletions
@@ -6,13 +6,22 @@
 from accelerate import init_empty_weights
 from huggingface_hub import hf_hub_download, snapshot_download
 from safetensors.torch import load_file
-from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
+from transformers import (
+    AutoProcessor,
+    AutoTokenizer,
+    CLIPVisionModelWithProjection,
+    UMT5EncoderModel,
+    Wav2Vec2ForCTC,
+    Wav2Vec2Processor,
+)
 
 from diffusers import (
     AutoencoderKLWan,
     UniPCMultistepScheduler,
     WanImageToVideoPipeline,
     WanPipeline,
+    WanS2VTransformer3DModel,
+    WanSpeechToVideoPipeline,
     WanTransformer3DModel,
     WanVACEPipeline,
     WanVACETransformer3DModel,
@@ -105,8 +114,59 @@
     "after_proj": "proj_out",
 }
 
+S2V_TRANSFORMER_KEYS_RENAME_DICT = {
+    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+    "time_projection.1": "condition_embedder.time_proj",
+    "head.modulation": "scale_shift_table",
+    "head.head": "proj_out",
+    "modulation": "scale_shift_table",
+    "ffn.0": "ffn.net.0.proj",
+    "ffn.2": "ffn.net.2",
+    # Hack to swap the layer names
+    # The original model calls the norms in the following order: norm1, norm3, norm2
+    # We convert it to: norm1, norm2, norm3
+    "norm2": "norm__placeholder",
+    "norm3": "norm2",
+    "norm__placeholder": "norm3",
+    # Add attention component mappings
+    "self_attn.q": "attn1.to_q",
+    "self_attn.k": "attn1.to_k",
+    "self_attn.v": "attn1.to_v",
+    "self_attn.o": "attn1.to_out.0",
+    "self_attn.norm_q": "attn1.norm_q",
+    "self_attn.norm_k": "attn1.norm_k",
+    "cross_attn.q": "attn2.to_q",
+    "cross_attn.k": "attn2.to_k",
+    "cross_attn.v": "attn2.to_v",
+    "cross_attn.o": "attn2.to_out.0",
+    "cross_attn.norm_q": "attn2.norm_q",
+    "cross_attn.norm_k": "attn2.norm_k",
+    "attn2.to_k_img": "attn2.add_k_proj",
+    "attn2.to_v_img": "attn2.add_v_proj",
+    "attn2.norm_k_img": "attn2.norm_added_k",
+    # S2V-specific audio component mappings
+    "casual_audio_encoder.encoder.conv2.conv": "condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv",
+    "casual_audio_encoder.encoder.conv3.conv": "condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv",
+    "casual_audio_encoder.weights": "condition_embedder.causal_audio_encoder.weighted_avg.weights",
+    # Pose condition encoder mappings
+    "cond_encoder.weight": "condition_embedder.pose_embedder.weight",
+    "cond_encoder.bias": "condition_embedder.pose_embedder.bias",
+    "trainable_cond_mask": "trainable_condition_mask",
+    "patch_embedding": "motion_in.patch_embedding",
+    # Audio injector attention mappings - convert the original q/k/v/o format to the diffusers format
+    **{
+        f"audio_injector.injector.{i}.{src}": f"audio_injector.injector.{i}.{dst}"
+        for i in range(12)
+        for src, dst in [("q", "to_q"), ("k", "to_k"), ("v", "to_v"), ("o", "to_out.0")]
+    },
+}
+
 TRANSFORMER_SPECIAL_KEYS_REMAP = {}
 VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+S2V_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
 
 
 def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
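For orientation, here is a minimal sketch of the pattern a rename map like this is meant for: rewriting original checkpoint keys into the diffusers layout. The loop below is an illustration under the assumption that `original_state_dict` holds the loaded original weights; it is not a copy of the script's exact code.

```python
# Illustrative sketch (assumption: `original_state_dict` is the loaded original checkpoint).
# Every occurrence of a source substring in a key is rewritten to its diffusers counterpart.
for key in list(original_state_dict.keys()):
    new_key = key
    for src, dst in S2V_TRANSFORMER_KEYS_RENAME_DICT.items():
        new_key = new_key.replace(src, dst)
    if new_key != key:
        original_state_dict[new_key] = original_state_dict.pop(key)
```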
@@ -364,6 +424,36 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
         }
         RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
         SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+    elif model_type == "Wan2.2-S2V-14B":
+        config = {
+            "model_id": "Wan-AI/Wan2.2-S2V-14B",
+            "diffusers_config": {
+                "added_kv_proj_dim": None,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "in_channels": 16,
+                "num_attention_heads": 40,
+                "num_layers": 40,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "audio_dim": 1024,
+                "audio_inject_layers": [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
+                "enable_adain": True,
+                "adain_mode": "attn_norm",
+                "pose_dim": 16,
+                "enable_framepack": True,
+                "framepack_drop_mode": "padd",
+                "add_last_motion": True,
+                "zero_timestep": True,
+            },
+        }
+        RENAME_DICT = S2V_TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = S2V_TRANSFORMER_SPECIAL_KEYS_REMAP
     return config, RENAME_DICT, SPECIAL_KEYS_REMAP
 
 
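As a quick sanity check on these numbers, assuming the S2V transformer derives its hidden width from heads and head dimension the same way the other Wan transformers in diffusers do:

```python
# Assumed relationship, mirroring the other Wan transformer configs in diffusers.
num_attention_heads, attention_head_dim = 40, 128
inner_dim = num_attention_heads * attention_head_dim
print(inner_dim)  # 5120, the hidden width of the 14B S2V transformer
```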
@@ -380,7 +470,9 @@ def convert_transformer(model_type: str, stage: str = None):
     original_state_dict = load_sharded_safetensors(model_dir)
 
     with init_empty_weights():
-        if "VACE" not in model_type:
+        if "S2V" in model_type:
+            transformer = WanS2VTransformer3DModel.from_config(diffusers_config)
+        elif "VACE" not in model_type:
             transformer = WanTransformer3DModel.from_config(diffusers_config)
         else:
             transformer = WanVACETransformer3DModel.from_config(diffusers_config)
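Because the model is instantiated under `init_empty_weights()` (i.e., on the meta device), the converted weights still have to be materialized into it afterwards. A minimal sketch of that step, assuming the renamed checkpoint lives in `original_state_dict`:

```python
# Sketch (assumption): materialize the converted weights into the meta-initialized model.
# `assign=True` replaces the empty meta tensors with the real tensors from the checkpoint.
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
```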
@@ -926,7 +1018,7 @@ def get_args():
 if __name__ == "__main__":
     args = get_args()
 
-    if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
+    if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "S2V" not in args.model_type:
         transformer = convert_transformer(args.model_type, stage="high_noise_model")
         transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
     else:
@@ -942,7 +1034,7 @@ def get_args():
     tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
     if "FLF2V" in args.model_type:
         flow_shift = 16.0
-    elif "TI2V" in args.model_type:
+    elif "TI2V" in args.model_type or "S2V" in args.model_type:
         flow_shift = 5.0
     else:
         flow_shift = 3.0
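For context, `flow_shift` parameterizes the flow-matching scheduler that the exported pipeline ships with. A hedged sketch of how such a scheduler is typically constructed here (the exact keyword set is an assumption, not confirmed by this hunk):

```python
# Sketch (assumption): the converted pipelines use a flow-matching UniPC scheduler,
# with `flow_shift` controlling the sigma shift (5.0 for TI2V/S2V, 3.0 otherwise here).
from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler(
    prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
)
```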
@@ -1016,6 +1108,22 @@ def get_args():
             vae=vae,
             scheduler=scheduler,
         )
+    elif "S2V" in args.model_type:
+        audio_encoder = Wav2Vec2ForCTC.from_pretrained(
+            "Wan-AI/Wan2.2-S2V-14B", subfolder="wav2vec2-large-xlsr-53-english"
+        )
+        audio_processor = Wav2Vec2Processor.from_pretrained(
+            "Wan-AI/Wan2.2-S2V-14B", subfolder="wav2vec2-large-xlsr-53-english"
+        )
+        pipe = WanSpeechToVideoPipeline(
+            transformer=transformer,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            vae=vae,
+            scheduler=scheduler,
+            audio_encoder=audio_encoder,
+            audio_processor=audio_processor,
+        )
     else:
         pipe = WanPipeline(
             transformer=transformer,
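The conversion presumably ends by serializing the assembled pipeline in the diffusers layout. A hedged sketch of that final step (`args.output_path` and the sharding settings are assumptions not shown in this diff):

```python
# Sketch (assumptions: the script exposes `args.output_path` and shards the safetensors weights).
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
```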
