In [1]:
%load_ext autoreload
%autoreload 2

In [40]:
import torch
import numpy as np
import pandas as pd
import os
import librosa
import cv2

from pathlib import Path
from random import shuffle
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData

In [3]:
vid_path = Path("dataset/train")
metadata = pd.read_csv("dataset/metadata.csv")
masks = Path("dataset/lip_masks")

In [26]:
# Load latents
latents_paths = [video_path for video_path in vid_path.iterdir() if video_path.suffix == ".pth"]
# shuffle(latents_paths)

In [5]:
model_manager = ModelManager(device="cpu")
model_manager.load_models(
    [
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ],
    torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)

Loading models from: models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors
    model_name: wan_video_dit model_class: WanModel
        This model is initialized with extra kwargs: {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
    The following models are loaded: ['wan_video_dit'].
Loading models from: models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
    model_name: wan_video_text_encoder model_class: WanTextEncoder
    The following models are loaded: ['wan_video_text_encoder'].
Loading models from: models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
    model_name: wan_video_vae model_class: WanVideoVAE
    The following models are loaded: ['wan_video_vae'].
Using wan_video_text_encoder from models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth.
Using wan_video_dit from models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pyt

In [54]:
# Load 5 latents
latents = []
audio_list = []

for latent_path in latents_paths[12:14]:
    latents.append(torch.load(latent_path)["latents"])
    name = latent_path.name.replace(".tensors.pth", "")

    caption = metadata[metadata["file_name"] == name]["text"].values[0]
    # Load audio
    audio_path = f"dataset/audio/{name.replace('.mp4', '.wav')}"

    print(os.path.exists(audio_path))

    audio_data, sr = librosa.load(audio_path, sr=None)
    
    # Add audio to list
    
    audio_list.append({
        'name': name,
        'audio': audio_data,
        'caption': caption
    })
    
latents = torch.stack(latents).to(pipe.device)

True
True


In [55]:
# play audio
print(audio_list[0]['caption'])

I'll see you in the next video.


In [59]:
# Play the audio
from IPython.display import Audio
display(Audio(audio_list[1]['audio'], rate=41000))

In [29]:
pipe.vae.to(pipe.device)

WanVideoVAE(
  (model): VideoVAE_(
    (encoder): Encoder3d(
      (conv1): AutoWrappedModule(
        (module): CausalConv3d(3, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1))
      )
      (downsamples): Sequential(
        (0): ResidualBlock(
          (residual): Sequential(
            (0): AutoWrappedModule(
              (module): RMS_norm()
            )
            (1): AutoWrappedModule(
              (module): SiLU()
            )
            (2): AutoWrappedModule(
              (module): CausalConv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1))
            )
            (3): AutoWrappedModule(
              (module): RMS_norm()
            )
            (4): AutoWrappedModule(
              (module): SiLU()
            )
            (5): AutoWrappedModule(
              (module): Dropout(p=0.0, inplace=False)
            )
            (6): AutoWrappedModule(
              (module): CausalConv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1))
            )
          )
   

In [30]:
out = pipe.vae.decode(latents, device='cuda')

In [31]:
for i in range(out.shape[0]):
    video = out[i].permute(1, 2, 3, 0).cpu().float()
    save_video(video, f"video_{i}.mp4", fps=30)



Saving video: 100%|██████████| 57/57 [00:01<00:00, 51.53it/s]
Saving video: 100%|██████████| 57/57 [00:01<00:00, 53.93it/s]
