In [1]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before each cell is executed, so edits to local source files take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download


# The three checkpoint files come from the same model repo and share every
# loading option, so only the file pattern varies between the ModelConfig
# entries. Order matches the load log: text encoder, diffusion model, VAE.
_wan_repo = "Wan-AI/Wan2.2-TI2V-5B"
_wan_weight_patterns = (
    "models_t5_umt5-xxl-enc-bf16.pth",
    "diffusion_pytorch_model*.safetensors",
    "Wan2.2_VAE.pth",
)

# NOTE(review): skip_download=True presumably means the files are expected to
# exist locally already (the log shows ./models/...) - confirm.
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(
            model_id=_wan_repo,
            origin_file_pattern=pattern,
            offload_device="cpu",
            skip_download=True,
        )
        for pattern in _wan_weight_patterns
    ],
)

  from .autonotebook import tqdm as notebook_tqdm


Loading models from: ./models/Wan-AI/Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth
    model_name: wan_video_text_encoder model_class: WanTextEncoder
    The following models are loaded: ['wan_video_text_encoder'].
Loading models from: ['./models/Wan-AI/Wan2.2-TI2V-5B/diffusion_pytorch_model-00003-of-00003-bf16.safetensors', './models/Wan-AI/Wan2.2-TI2V-5B/diffusion_pytorch_model-00002-of-00003-bf16.safetensors', './models/Wan-AI/Wan2.2-TI2V-5B/diffusion_pytorch_model-00001-of-00003-bf16.safetensors']
    model_name: wan_video_dit model_class: WanModel
        This model is initialized with extra kwargs: {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
    The following models are loaded: ['wan_video_dit'].
Loa

In [6]:
# Read the source clip; frames are requested at height 480, width 832.
mp4_path = "/data2/saikiran.tedla/hdrvideo/diff/stuttgart_input.mp4"
frames = VideoData(mp4_path, height=480, width=832).raw_data()
print("pipe device:", pipe.device)
input_video = pipe.preprocess_video(frames)
pipe.load_models_to_device(["vae"])

# Make sure the VAE weights are on the pipeline device. Only the device is
# changed here; the dtype of the weights is left as loaded.
pipe.vae = pipe.vae.to(device=pipe.device)
pipe.vae.eval()  # inference only

# VAE round trip: encode the clip to latents, then decode it straight back.
# This is pure inference, so run it without autograd to avoid keeping
# activations alive for a backward pass that never happens.
with torch.no_grad():
    input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=False).to(
        dtype=pipe.torch_dtype, device=pipe.device
    )
    decoded_video = pipe.vae.decode(input_latents, device=pipe.device, tiled=False).to(
        dtype=torch.float32, device="cpu"
    )

# Convert the decoded tensor into the frame sequence used below for the
# PSNR comparison and for saving. A separate name is used for the tensor so
# `output_video` only ever refers to the final frames.
output_video = pipe.vae_output_to_video(decoded_video)

import numpy as np
import math
def psnr(img1, img2, max_value=255.0):
    """Compute the peak signal-to-noise ratio (in dB) between two images.

    Args:
        img1: First image, anything ``np.asarray`` accepts (e.g. an
            (H, W, C) uint8 array).
        img2: Second image; must have exactly the same shape as ``img1``.
        max_value: Peak signal value of the image range. Defaults to 255.0
            (8-bit images); pass 1.0 for images normalised to [0, 1].

    Returns:
        PSNR in decibels as a Python float, or ``float("inf")`` when the two
        images are identical.

    Raises:
        ValueError: If the shapes differ or the inputs are empty.
    """
    # Convert to float before subtracting so unsigned integer inputs cannot
    # wrap around; float64 keeps the mean accurate for large frames.
    reference = np.asarray(img1, dtype=np.float64)
    candidate = np.asarray(img2, dtype=np.float64)

    # Without this check NumPy would silently broadcast (e.g. (H, W, C)
    # against (H, W, 1)) and produce a meaningless score.
    if reference.shape != candidate.shape:
        raise ValueError(
            f"psnr expects images of identical shape, got {reference.shape} and {candidate.shape}"
        )
    if reference.size == 0:
        raise ValueError("psnr is undefined for empty images")

    mse = float(np.mean((reference - candidate) ** 2))
    if mse == 0:
        return float("inf")  # identical images
    return 20 * math.log10(max_value / math.sqrt(mse))

# Score every (input frame, reconstructed frame) pair. Each frame is turned
# into a numpy array before it is handed to psnr(); zip stops at the shorter
# of the two sequences.
psnr_values = [
    psnr(np.array(source_frame), np.array(decoded_frame))
    for source_frame, decoded_frame in zip(frames, output_video)
]

# Report the per-frame scores.
for i, value in enumerate(psnr_values):
    print(f"Frame {i}: PSNR = {value:.2f} dB")


# Write both clips to disk at the same frame rate for side-by-side viewing.
for clip, filename in (
    (frames, "input_video.mp4"),
    (output_video, "output_video.mp4"),
):
    save_video(clip, filename, fps=15)

pipe device: cuda
Frame 0: PSNR = 28.68 dB
Frame 1: PSNR = 29.10 dB
Frame 2: PSNR = 26.85 dB
Frame 3: PSNR = 27.01 dB
Frame 4: PSNR = 28.23 dB
Frame 5: PSNR = 25.49 dB
Frame 6: PSNR = 26.00 dB
Frame 7: PSNR = 26.71 dB
Frame 8: PSNR = 27.86 dB
Frame 9: PSNR = 24.79 dB
Frame 10: PSNR = 26.67 dB
Frame 11: PSNR = 26.84 dB
Frame 12: PSNR = 28.00 dB
Frame 13: PSNR = 26.01 dB
Frame 14: PSNR = 27.12 dB
Frame 15: PSNR = 26.53 dB
Frame 16: PSNR = 27.31 dB
Frame 17: PSNR = 27.66 dB
Frame 18: PSNR = 29.47 dB
Frame 19: PSNR = 29.84 dB
Frame 20: PSNR = 30.41 dB
Frame 21: PSNR = 30.00 dB
Frame 22: PSNR = 30.83 dB
Frame 23: PSNR = 31.48 dB
Frame 24: PSNR = 32.39 dB
Frame 25: PSNR = 29.11 dB
Frame 26: PSNR = 30.09 dB
Frame 27: PSNR = 30.56 dB
Frame 28: PSNR = 31.82 dB
Frame 29: PSNR = 29.77 dB
Frame 30: PSNR = 30.26 dB
Frame 31: PSNR = 29.63 dB
Frame 32: PSNR = 32.43 dB
Frame 33: PSNR = 32.16 dB
Frame 34: PSNR = 33.94 dB
Frame 35: PSNR = 34.26 dB
Frame 36: PSNR = 34.90 dB
Frame 37: PSNR = 33.89 dB
Fram

Saving video: 100%|██████████| 49/49 [00:00<00:00, 113.13it/s]
Saving video: 100%|██████████| 49/49 [00:00<00:00, 108.59it/s]
