In [None]:
import torch
import numpy as np
import pandas as pd
import os
import cv2
from pathlib import Path

from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData

In [2]:
# Dataset locations: training samples and the lip-mask directory.
vid_path = Path("dataset/train")
masks = Path("dataset/lip_masks")

# Per-sample metadata table that ships alongside the dataset.
metadata = pd.read_csv("dataset/metadata.csv")

In [3]:
# Collect the cached latent files (*.pth) stored in the training directory.
# Path.iterdir() yields entries in arbitrary, filesystem-dependent order, so the
# result is sorted to make any later slice of this list reproducible across runs.
latents_paths = sorted(
    video_path for video_path in vid_path.iterdir() if video_path.suffix == ".pth"
)

In [None]:
# Checkpoint files handed to the model manager in a single load_models call.
wan_checkpoints = [
    "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
    "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
    "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
]
# You can set this to `torch.float8_e4m3fn` to enable FP8 quantization.
load_dtype = torch.bfloat16

# The manager is created with device="cpu"; the pipeline is then built with device="cuda".
model_manager = ModelManager(device="cpu")
model_manager.load_models(wan_checkpoints, torch_dtype=load_dtype)

pipe = WanVideoPipeline.from_model_manager(
    model_manager,
    torch_dtype=torch.bfloat16,
    device="cuda",
)
# NOTE(review): the meaning of num_persistent_param_in_dit=None is defined by diffsynth — confirm.
pipe.enable_vram_management(num_persistent_param_in_dit=None)

In [5]:
# Load the first 5 cached latents and stack them into one batch on the pipeline device.
# map_location="cpu" makes deserialization independent of the device the file was
# saved from; the stacked batch is moved to the pipeline device afterwards.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pth file cannot execute arbitrary code on load.
# NOTE(review): assumes each file is a dict of tensors with a "latents" entry — confirm.
latents = torch.stack(
    [
        torch.load(latent_path, map_location="cpu", weights_only=True)["latents"]
        for latent_path in latents_paths[:5]
    ]
).to(pipe.device)

In [None]:
# Move the VAE onto the pipeline's device before it is used for decoding.
# NOTE(review): presumably needed because VRAM management may leave the VAE off-device — confirm.
# As the last expression of the cell, the returned module is also displayed.
pipe.vae.to(pipe.device)

In [9]:
# Decode the latent batch with the pipeline's VAE.
# This is inference only, so autograd is disabled explicitly to guarantee no graph
# is retained for the decoded videos. The device is taken from the pipeline instead
# of a hard-coded string so it stays consistent with the other cells.
with torch.no_grad():
    out = pipe.vae.decode(latents, device=pipe.device)

In [24]:
# Write every decoded clip to its own mp4 file.
# Frames are converted to uint8 explicitly: handing float frames to the writer
# leaves the float -> 8-bit conversion to the video backend, which is lossy and
# can rescale each frame independently.
# NOTE(review): assumes the decoder output lies in [-1, 1] (the mapping used by the
# pipeline's own tensor-to-video conversion) — confirm for this diffsynth version.
for i in range(out.shape[0]):
    # Move the leading axis last so frames come first (presumably
    # (C, T, H, W) -> (T, H, W, C) — confirm); float() upcasts before the arithmetic.
    clip = out[i].detach().float().cpu().permute(1, 2, 3, 0)
    frames = ((clip + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).numpy()
    save_video(frames, f"video_{i}.mp4", fps=30)



Saving video: 100%|██████████| 81/81 [00:01<00:00, 75.11it/s]
Saving video: 100%|██████████| 81/81 [00:01<00:00, 76.19it/s]
Saving video: 100%|██████████| 81/81 [00:01<00:00, 76.22it/s]
Saving video: 100%|██████████| 81/81 [00:01<00:00, 75.79it/s]
Saving video: 100%|██████████| 81/81 [00:01<00:00, 75.99it/s]


: 