In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
from torchmetrics import PeakSignalNoiseRatio
import numpy as np



# All three checkpoints come from the same local Wan2.2 TI2V-5B snapshot
# (skip_download=True) and are offloaded to the CPU when not in use.
wan_checkpoint_patterns = [
    "models_t5_umt5-xxl-enc-bf16.pth",       # text encoder weights
    "diffusion_pytorch_model*.safetensors",  # DiT weights (sharded)
    "Wan2.2_VAE.pth",                        # VAE weights
]

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda:5",
    model_configs=[
        ModelConfig(
            model_id="Wan-AI/Wan2.2-TI2V-5B",
            origin_file_pattern=checkpoint_pattern,
            offload_device="cpu",
            skip_download=True,
        )
        for checkpoint_pattern in wan_checkpoint_patterns
    ],
)

# Only the VAE is exercised in this notebook, so it is the only model requested
# on the compute device.
pipe.load_models_to_device(["vae"])

# Explicitly place the VAE on the pipeline device. Note that this call moves
# the module only; it does not change the parameter dtype.
vae_on_device = pipe.vae.to(device=pipe.device)
pipe.vae = vae_on_device

# Switch to inference mode. Kept as the last expression of the cell so the
# module summary is still displayed.
pipe.vae.eval()

  from .autonotebook import tqdm as notebook_tqdm


Loading models from: ./models/Wan-AI/Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth
    model_name: wan_video_text_encoder model_class: WanTextEncoder
    The following models are loaded: ['wan_video_text_encoder'].
Loading models from: ['./models/Wan-AI/Wan2.2-TI2V-5B/diffusion_pytorch_model-00003-of-00003-bf16.safetensors', './models/Wan-AI/Wan2.2-TI2V-5B/diffusion_pytorch_model-00002-of-00003-bf16.safetensors', './models/Wan-AI/Wan2.2-TI2V-5B/diffusion_pytorch_model-00001-of-00003-bf16.safetensors']
    model_name: wan_video_dit model_class: WanModel
        This model is initialized with extra kwargs: {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
    The following models are loaded: ['wan_video_dit'].
Loa

WanVideoVAE38(
  (model): VideoVAE38_(
    (encoder): Encoder3d_38(
      (conv1): CausalConv3d(12, 160, kernel_size=(3, 3, 3), stride=(1, 1, 1))
      (downsamples): Sequential(
        (0): Down_ResidualBlock(
          (avg_shortcut): AvgDown3D()
          (downsamples): Sequential(
            (0): ResidualBlock(
              (residual): Sequential(
                (0): RMS_norm()
                (1): SiLU()
                (2): CausalConv3d(160, 160, kernel_size=(3, 3, 3), stride=(1, 1, 1))
                (3): RMS_norm()
                (4): SiLU()
                (5): Dropout(p=0.0, inplace=False)
                (6): CausalConv3d(160, 160, kernel_size=(3, 3, 3), stride=(1, 1, 1))
              )
              (shortcut): Identity()
            )
            (1): ResidualBlock(
              (residual): Sequential(
                (0): RMS_norm()
                (1): SiLU()
                (2): CausalConv3d(160, 160, kernel_size=(3, 3, 3), stride=(1, 1, 1))
                (3):

In [3]:
from diffsynth.trainers.stuttgart_dataset import StuttgartDataset
import numpy as np
def average_frame_psnr(pred: torch.Tensor, target: torch.Tensor, data_range: float = 1.0) -> torch.Tensor:
    """
    Computes the average per-frame PSNR between two videos.

    Each frame's PSNR is obtained by passing the [B, C, H, W] slice at time
    index t to torchmetrics' PeakSignalNoiseRatio; the per-frame values are
    then averaged with equal weight.

    Args:
        pred (torch.Tensor): Predicted video of shape [B, C, T, H, W].
        target (torch.Tensor): Ground-truth video of shape [B, C, T, H, W].
        data_range (float): Max value of the signal (1.0 if normalized).

    Returns:
        torch.Tensor: Scalar tensor with average PSNR over frames.

    Raises:
        AssertionError: If pred and target do not have the same shape.
        ValueError: If the inputs are not 5-D or contain zero frames.
    """
    assert pred.shape == target.shape, "Prediction and target must have the same shape"
    if pred.ndim != 5:
        raise ValueError(
            f"Expected 5-D tensors of shape [B, C, T, H, W], got shape {tuple(pred.shape)}"
        )

    num_frames = pred.shape[2]
    if num_frames == 0:
        # torch.stack on an empty list fails with an opaque message; fail early instead.
        raise ValueError("Cannot compute PSNR for a video with zero frames")

    # A torchmetrics metric keeps its internal state on its own device, so it has
    # to live on the same device as the tensors it is fed (it is created on CPU).
    psnr_metric = PeakSignalNoiseRatio(data_range=data_range).to(pred.device)

    # Calling the metric returns the PSNR of that call's inputs, i.e. one value per frame.
    psnrs = [psnr_metric(pred[:, :, t], target[:, :, t]) for t in range(num_frames)]

    return torch.stack(psnrs).mean()

# Single Stuttgart scene used for the VAE round-trip check; the same directory
# is handed to both the dataset and its video operator.
stuttgart_scene_dir = "/data2/saikiran.tedla/hdrvideo/diff/data/stuttgart/carousel_fireworks_02"

# Operator that loads the clip: 13 frames at 480x832, with spatial sizes
# constrained to multiples of 16 and the frame count to 4k + 1.
stuttgart_video_operator = StuttgartDataset.default_video_operator(
    base_path=stuttgart_scene_dir,
    max_pixels=1280*720,
    height=480,
    width=832,
    height_division_factor=16,
    width_division_factor=16,
    num_frames=13,
    time_division_factor=4,
    time_division_remainder=1,
)

# "hdr_and_brackets" mode: each sample carries both "video" and "hdr_video"
# entries (both are read by the sample-preparation code further down).
dataset = StuttgartDataset(
    base_path=stuttgart_scene_dir,
    repeat=1,
    main_data_operator=stuttgart_video_operator,
    mode="hdr_and_brackets",
)

# Pull the first (and, with repeat=1, presumably only) sample.
data = dataset[0]

# --- Bracketed frames: list of PIL images -> float tensor in [0, 1] ---------
bracket_frames = data["video"]
per_frame_tensors = []
for img in bracket_frames:
    # PIL image -> (H, W, C) array -> (C, H, W) tensor
    per_frame_tensors.append(torch.from_numpy(np.array(img)).permute(2, 0, 1))
bracket_video = torch.stack(per_frame_tensors, dim=0).float() / 255.0  # (T, C, H, W)
bracket_video = bracket_video.permute(1, 0, 2, 3)  # (C, T, H, W)

# --- HDR frames: array indexed (T, H, W, C) -> (C, T, H, W) -----------------
hdr_video = torch.from_numpy(data["hdr_video"]).permute(3, 0, 1, 2)

# Resize the HDR frames to the bracket resolution. interpolate() treats the 4-D
# input as (N, C, H, W), so only the last two (spatial) axes are resampled.
# The trailing squeeze(0) only has an effect if the leading axis has size 1.
target_hw = bracket_video.shape[2:]
hdr_video = torch.nn.functional.interpolate(
    hdr_video, size=target_hw, mode="bilinear", align_corners=False
).squeeze(0)

# Add the batch axis, then move to the pipeline device and cast to bfloat16.
bracket_video = bracket_video.unsqueeze(0).to(pipe.device).to(torch.bfloat16)
hdr_video = hdr_video.unsqueeze(0).to(pipe.device).to(torch.bfloat16)

# Duplicate the final HDR frame (4 -> 5 frames here) and normalise by the peak
# value; the peak is taken before concatenation, which leaves it unchanged.
hdr_peak = torch.max(hdr_video)
hdr_last_frame = hdr_video[:, :, -1:]
hdr_video = torch.cat([hdr_video, hdr_last_frame], dim=2) / hdr_peak

print("bracket video shape:", bracket_video.shape)
print("hdr video shape:", hdr_video.shape)

bracket video shape: torch.Size([1, 3, 13, 480, 832])
hdr video shape: torch.Size([1, 3, 5, 480, 832])


In [None]:
#read in mp4 video
# mp4_path = "/data2/saikiran.tedla/hdrvideo/diff/stuttgart_input.mp4"
# frames = VideoData(mp4_path, height=480, width=832).raw_data()
#input_video = pipe.preprocess_video(frames)

@torch.no_grad()
def encode_and_decode(video):
    """Round-trip a video through the pipeline's VAE without tracking gradients.

    The input is rescaled with ``video * 2 - 1`` (i.e. it is expected to lie in
    [0, 1] -- the function itself does not verify this), encoded and decoded
    un-tiled on ``pipe.device``, mapped back with ``(x + 1) / 2`` and clamped
    to [0, 1].

    Args:
        video: Video tensor accepted by ``pipe.vae.encode``; the callers in
            this notebook pass tensors of shape [1, C, T, H, W].

    Returns:
        The reconstructed video as a float32 CPU tensor with values in [0, 1].
    """
    # [0, 1] -> [-1, 1]
    normalized = video * 2.0 - 1.0

    latents = pipe.vae.encode(normalized, device=pipe.device, tiled=False)
    latents = latents.to(dtype=pipe.torch_dtype, device=pipe.device)

    reconstruction = pipe.vae.decode(latents, device=pipe.device, tiled=False)
    reconstruction = reconstruction.to(dtype=torch.float32, device="cpu")

    # [-1, 1] -> [0, 1], discarding any overshoot produced by the decoder.
    # (Conversion via pipe.vae_output_to_video is left disabled here.)
    return torch.clamp((reconstruction + 1.0) / 2.0, 0.0, 1.0)

# Round-trip both clips through the VAE.
decoded_bracket_video = encode_and_decode(bracket_video)
decoded_hdr_video = encode_and_decode(hdr_video)

# Per-frame PSNR of each reconstruction against its own input, compared in
# float32 on the CPU (where encode_and_decode leaves its result).
psnr_bracket = average_frame_psnr(
    decoded_bracket_video,
    bracket_video.to(dtype=torch.float32, device="cpu"),
    data_range=1.0,
)
psnr_hdr = average_frame_psnr(
    decoded_hdr_video,
    hdr_video.to(dtype=torch.float32, device="cpu"),
    data_range=1.0,
)

print(f"PSNR between decoded bracket video and original bracket video: {psnr_bracket.item():.2f} dB")
print(f"PSNR between decoded HDR video and original HDR video: {psnr_hdr.item():.2f} dB")

# Write reconstructions and originals to disk for visual inspection.
from utils import output_hdr_video

# The writer is given frame-major clips, so swap the channel and time axes:
# (B, C, T, H, W) -> (B, T, C, H, W).
decoded_bracket_video = decoded_bracket_video.permute(0, 2, 1, 3, 4)
decoded_hdr_video = decoded_hdr_video.permute(0, 2, 1, 3, 4)

# (clip as (T, C, H, W), output path) pairs, written in this order.
for clip_frames, clip_destination in (
    (
        decoded_bracket_video[0],
        "/data2/saikiran.tedla/hdrvideo/diff/encoder_test/decoded_bracket_video",
    ),
    (
        decoded_hdr_video[0],
        "/data2/saikiran.tedla/hdrvideo/diff/encoder_test/decoded_hdr_video",
    ),
    (
        bracket_video[0].permute(1, 0, 2, 3).to(dtype=torch.float32, device="cpu"),
        "/data2/saikiran.tedla/hdrvideo/diff/encoder_test/original_bracket_video",
    ),
    (
        hdr_video[0].permute(1, 0, 2, 3).to(dtype=torch.float32, device="cpu"),
        "/data2/saikiran.tedla/hdrvideo/diff/encoder_test/original_hdr_video",
    ),
):
    output_hdr_video(clip_frames, clip_destination)



PSNR between decoded bracket video and original bracket video: 16.51 dB
PSNR between decoded HDR video and original HDR video: 32.93 dB
Frame shape: (480, 832, 3)
Mean of frame: 0.3200551
Frame shape: (480, 832, 3)
Mean of frame: 0.31974134
Frame shape: (480, 832, 3)
Mean of frame: 0.32709792
Frame shape: (480, 832, 3)
Mean of frame: 0.3599721
Frame shape: (480, 832, 3)
Mean of frame: 0.33332375
Frame shape: (480, 832, 3)
Mean of frame: 0.026476841
Frame shape: (480, 832, 3)
Mean of frame: 0.026570586
Frame shape: (480, 832, 3)
Mean of frame: 0.03343632
Frame shape: (480, 832, 3)
Mean of frame: 0.027944708
Frame shape: (480, 832, 3)
Mean of frame: 0.66885054
Frame shape: (480, 832, 3)
Mean of frame: 0.68392664
Frame shape: (480, 832, 3)
Mean of frame: 0.70576483
Frame shape: (480, 832, 3)
Mean of frame: 0.68491715
Frame shape: (480, 832, 3)
Mean of frame: 0.00025634698
Frame shape: (480, 832, 3)
Mean of frame: 0.00024913612
Frame shape: (480, 832, 3)
Mean of frame: 0.00045346507
Frame 