In [109]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [110]:
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
import numpy as np
from utils import process_bracketed_video
from utils import output_hdr_video, process_bracketed_video
from utils import average_frame_psnr



pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda:5",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu", skip_download=True),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu", skip_download=True),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu", skip_download=True),
    ],
)

pipe.load_models_to_device(["vae"])

# ensure the VAE is actually on the GPU in the right dtype
pipe.vae = pipe.vae.to(device=pipe.device)
pipe.vae.eval()  # optional but good practice

Loading models from: ./models/Wan-AI/Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth


    model_name: wan_video_text_encoder model_class: WanTextEncoder
    The following models are loaded: ['wan_video_text_encoder'].
Loading models from: ['./models/Wan-AI/Wan2.2-TI2V-5B/diffusion_pytorch_model-00003-of-00003-bf16.safetensors', './models/Wan-AI/Wan2.2-TI2V-5B/diffusion_pytorch_model-00002-of-00003-bf16.safetensors', './models/Wan-AI/Wan2.2-TI2V-5B/diffusion_pytorch_model-00001-of-00003-bf16.safetensors']
    model_name: wan_video_dit model_class: WanModel
        This model is initialized with extra kwargs: {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
    The following models are loaded: ['wan_video_dit'].
Loading models from: ./models/Wan-AI/Wan2.2-TI2V-5B/Wan2.2_VAE.pth
    model_name: wan_

WanVideoVAE38(
  (model): VideoVAE38_(
    (encoder): Encoder3d_38(
      (conv1): CausalConv3d(12, 160, kernel_size=(3, 3, 3), stride=(1, 1, 1))
      (downsamples): Sequential(
        (0): Down_ResidualBlock(
          (avg_shortcut): AvgDown3D()
          (downsamples): Sequential(
            (0): ResidualBlock(
              (residual): Sequential(
                (0): RMS_norm()
                (1): SiLU()
                (2): CausalConv3d(160, 160, kernel_size=(3, 3, 3), stride=(1, 1, 1))
                (3): RMS_norm()
                (4): SiLU()
                (5): Dropout(p=0.0, inplace=False)
                (6): CausalConv3d(160, 160, kernel_size=(3, 3, 3), stride=(1, 1, 1))
              )
              (shortcut): Identity()
            )
            (1): ResidualBlock(
              (residual): Sequential(
                (0): RMS_norm()
                (1): SiLU()
                (2): CausalConv3d(160, 160, kernel_size=(3, 3, 3), stride=(1, 1, 1))
                (3):

In [163]:
from diffsynth.trainers.stuttgart_dataset import StuttgartDataset
import numpy as np


dataset = StuttgartDataset(
    base_path="/data2/saikiran.tedla/hdrvideo/diff/data/stuttgart/carousel_fireworks_02",
    repeat=1,
    main_data_operator=StuttgartDataset.default_video_operator(
        base_path="/data2/saikiran.tedla/hdrvideo/diff/data/stuttgart/carousel_fireworks_02",
        max_pixels=1280*720,
        height=480,
        width=832,
        height_division_factor=16,
        width_division_factor=16,
        num_frames=13,
        time_division_factor=4,
        time_division_remainder=1,
    ),
    mode = "hdr_and_brackets"
)

data = dataset[50]
bracket_video = data["video"]
#convert bracket video (list of PIL images) to tensor
bracket_video = torch.stack([torch.from_numpy(np.array(img)).permute(2,0,1) for img in bracket_video], dim=0).float() / 255.0  # shape (T, C, H, W)
hdr_video = torch.from_numpy(data["hdr_video"])
hdr_video = hdr_video.permute(0,3,1,2)  # shape (C, T, H, W)
#reshape hdr_video to same H,W as bracket_video
hdr_video = torch.nn.functional.interpolate(hdr_video, size=bracket_video.shape[2:], mode="bilinear", align_corners=False).squeeze(0)
bracket_video = bracket_video.unsqueeze(0).to(pipe.device).to(torch.bfloat16) 
hdr_video = hdr_video.unsqueeze(0).to(pipe.device).to(torch.bfloat16)
#repeat the last hdr_video_frame so we get 5 frames
hdr_video_padded = torch.cat([hdr_video[:,0:1], hdr_video], dim=1)

normal_exposure = bracket_video[:, 1:5] #EV 0
low_exposure = bracket_video[:, 5:9] #EV -4
high_exposure = bracket_video[:, 9:13] #EV +4

normal_exposure_padded = torch.cat([normal_exposure[:,0:1], normal_exposure], dim=1)
low_exposure_padded = torch.cat([low_exposure[:,0:1], low_exposure], dim=1)
high_exposure_padded = torch.cat([high_exposure[:,0:1], high_exposure], dim=1)
#bracket_video[:, 5:9] = bracket_video[:, 5:9] ** (1/2)  # apply gamma correction to the normal exposure frames
#bracket_video[:, 5:9] = bracket_video[:, 5:9] ** (1/2.2)  # apply gamma correction to the normal exposure frames

bracket_with_pad = torch.cat([normal_exposure[:,0:1], normal_exposure, 
                                low_exposure[:, 0:1], low_exposure[:, 0:1], low_exposure[:, 0:1], low_exposure[:, 0:1], low_exposure,
                                high_exposure[:, 0:1], high_exposure[:, 0:1], high_exposure[:, 0:1], high_exposure[:, 0:1], high_exposure
                                ], dim=1)

RouteByType operator_map: [(<class 'str'>, <diffsynth.trainers.stuttgart_dataset.DataProcessingPipeline object at 0x14ad900ba6e0>)]
MAX: 624.5 min: -0.14001465
Image.shape (1080, 1920, 3) max value: 16.0
tensor shape: torch.Size([3, 1080, 1920]) max value: tensor(16.)
MAX: 690.0 min: -0.121032715
Image.shape (1080, 1920, 3) max value: 16.0
tensor shape: torch.Size([3, 1080, 1920]) max value: tensor(16.)


MAX: 659.5 min: -0.13256836
Image.shape (1080, 1920, 3) max value: 16.0
tensor shape: torch.Size([3, 1080, 1920]) max value: tensor(16.)
MAX: 629.0 min: -0.09844971
Image.shape (1080, 1920, 3) max value: 16.0
tensor shape: torch.Size([3, 1080, 1920]) max value: tensor(16.)


In [164]:
#read in mp4 video
# mp4_path = "/data2/saikiran.tedla/hdrvideo/diff/stuttgart_input.mp4"
# frames = VideoData(mp4_path, height=480, width=832).raw_data()
#input_video = pipe.preprocess_video(frames)

def encode_and_decode(video):
    with torch.no_grad():
        video = video * 2.0 - 1.0  # scale to [-1, 1]
        video = video.permute(0,2,1,3,4)  # (B, T, C, H, W) to (B, C, T, H, W)
        input_latents = pipe.vae.encode(video, device=pipe.device, tiled=False).to(dtype=pipe.torch_dtype, device=pipe.device)
        output_video = pipe.vae.decode(input_latents, device=pipe.device, tiled=False).to(dtype=torch.float32, device="cpu")
        output_video = (output_video + 1.0) / 2.0  # scale to [0, 1]
        output_video = torch.clamp(output_video, 0.0, 1.0) #clip output video to [0,1]
        output_video = output_video.permute(0,2,1,3,4)  # (B, C, T, H, W) to (B, T, C, H, W)
    return output_video

decoded_bracket_video = encode_and_decode(bracket_video)
decoded_normal_exposure = encode_and_decode(normal_exposure_padded)
decoded_low_exposure = encode_and_decode(low_exposure_padded)
decoded_high_exposure = encode_and_decode(high_exposure_padded)
decoded_bracket_with_pad = encode_and_decode(bracket_with_pad)

#combine the decoded exposures to get hdr video
decoded_combined = torch.cat([decoded_normal_exposure[:, 1:5], decoded_low_exposure[:, 1:5],  decoded_high_exposure[:, 1:5]], dim=1)
#remove padding
decoded_bracket_with_pad = torch.cat([decoded_bracket_with_pad[:, 1:5], decoded_bracket_with_pad[:, 9:13], decoded_bracket_with_pad[:, 16:20]], dim=1)


#encode and decode the hdr video
max_hdr_value = torch.max(hdr_video)
decoded_hdr_video_padded = encode_and_decode(hdr_video_padded/max_hdr_value) * max_hdr_value.cpu()
decoded_hdr_video = decoded_hdr_video_padded[:, 1:]

print("decoded_bracket_video[0][1:] shape:", decoded_bracket_video[0][1:].shape)
bracket_to_hdr_video = process_bracketed_video(decoded_bracket_video[0][1:]).unsqueeze(0)# with decoded
seperate_bracket_to_hdr_video = process_bracketed_video(decoded_combined[0].to(dtype=torch.float32, device="cpu")).unsqueeze(0) # with original bracket video
bracket_w_pad_video = process_bracketed_video(decoded_bracket_with_pad[0].to(dtype=torch.float32, device="cpu")).unsqueeze(0) # with original bracket video


max_value = torch.max(hdr_video).item()
print("Using max: ", max_value)
psnr_bracket = average_frame_psnr(decoded_bracket_video, bracket_video, data_range=max_value)
psnr_hdr = average_frame_psnr(decoded_hdr_video, hdr_video, data_range=max_value)
psnr_bracket_to_hdr = average_frame_psnr(bracket_to_hdr_video, hdr_video, data_range=max_value)
seperate_psnr_bracket_to_hdr = average_frame_psnr(seperate_bracket_to_hdr_video, hdr_video, data_range=max_value)
psnr_bracket_w_pad = average_frame_psnr(bracket_w_pad_video, hdr_video, data_range=max_value)

print(f"PSNR between decoded bracket video and original bracket video: {psnr_bracket.item():.2f} dB")
print(f"PSNR between decoded HDR video and original HDR video: {psnr_hdr.item():.2f} dB")
print(f"PSNR between bracket-to-HDR video and original HDR video: {psnr_bracket_to_hdr.item():.2f} dB")
print(f"PSNR between seperate bracket-to-HDR video and original HDR video: {seperate_psnr_bracket_to_hdr.item():.2f} dB")
print(f"PSNR between bracket with padding to HDR video and original HDR video: {psnr_bracket_w_pad.item():.2f} dB")
# Save the videos for visual inspection
output_hdr_video(decoded_bracket_video[0], "/data2/saikiran.tedla/hdrvideo/diff/encoder_test/decoded_bracket_video")
output_hdr_video(decoded_hdr_video[0], "/data2/saikiran.tedla/hdrvideo/diff/encoder_test/decoded_hdr_video")
output_hdr_video(bracket_video[0].to(dtype=torch.float32, device="cpu"), "/data2/saikiran.tedla/hdrvideo/diff/encoder_test/original_bracket_video")
output_hdr_video(hdr_video[0].to(dtype=torch.float32, device="cpu"), "/data2/saikiran.tedla/hdrvideo/diff/encoder_test/original_hdr_video")
output_hdr_video(bracket_to_hdr_video[0], "/data2/saikiran.tedla/hdrvideo/diff/encoder_test/bracket_to_hdr_video")
output_hdr_video(seperate_bracket_to_hdr_video[0], "/data2/saikiran.tedla/hdrvideo/diff/encoder_test/seperate_bracket_to_hdr_video")
output_hdr_video(bracket_w_pad_video[0], "/data2/saikiran.tedla/hdrvideo/diff/encoder_test/bracket_with_pad_to_hdr_video")

decoded_bracket_video[0][1:] shape: torch.Size([12, 3, 480, 832])
Using max:  16.0
PSNR between decoded bracket video and original bracket video: 52.28 dB
PSNR between decoded HDR video and original HDR video: 30.42 dB
PSNR between bracket-to-HDR video and original HDR video: 27.48 dB
PSNR between seperate bracket-to-HDR video and original HDR video: 30.43 dB
PSNR between bracket with padding to HDR video and original HDR video: 29.86 dB
