In [4]:
import sys
import torch
sys.path.append('/home/nthadishetty1/frame_interpollation/CameraCtrl')

from CameraCtrl.cameractrl.models.pose_adaptor import CameraPoseEncoder, PoseAdaptor
from data import StereoEventDataset
lucker_embeddings = StereoEventDataset.plucker_embeddings


train_dataset = StereoEventDataset(video_data_dir="/home/nthadishetty1/frame_interpollation/stereo_vkitti_folders",frame_height=375,frame_width=375)
train_dataloader = torch.utils.data.DataLoader(train_dataset,shuffle=None,collate_fn=None,batch_size=1, num_workers=1)

for batch_idx, batch in enumerate(train_dataloader):
    if batch_idx >= 3:
        break
    video_name = batch['video_name'][0]     
    left_data = batch['left']
    right_data = batch['right']
    left_meta_path = batch['left']['metadata_path']
    right_meta_path = batch['right']['metadata_path']


Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at /pytorch/aten/src/ATen/native/Cross.cpp:63.)
  rays_dxo = torch.cross(rays_o, rays_d)                          # B, V, HW, 3


KeyError: 'metadata_path'

In [6]:
import numpy as np 
k =np.load('/data/venkateswara_lab/frame_interpollation/bs_ergb/3_TRAINING/acquarium_02/events/000000.npz')
print(k)

NpzFile '/data/venkateswara_lab/frame_interpollation/bs_ergb/3_TRAINING/acquarium_02/events/000000.npz' with keys: x, y, timestamp, polarity


In [1]:
import os
from glob import glob
import random
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from natsort import natsorted
import cv2
import time

# ------------------- Helper Functions -------------------

def mask_function(event_image, kernel_size=31, kernel_size_erode=61,kernel_size_midele=31, iterations=1, sigma_log=10):
    max_value = np.max(np.abs(event_image))
    if max_value != 0:
        event_image = np.abs(event_image) / max_value
    else:
        event_image = np.abs(event_image)

    event_image_blurred = cv2.GaussianBlur(event_image, (kernel_size,kernel_size), sigma_log)
    _, binary_image = cv2.threshold(event_image_blurred, 0.01, 1, cv2.THRESH_BINARY)
    kernel_dilate = np.ones((kernel_size_erode, kernel_size_erode), np.uint8)
    binary_image_dilated = cv2.dilate(binary_image, kernel_dilate, iterations=iterations)
    binary_median = cv2.medianBlur(binary_image_dilated.astype(np.uint8), kernel_size_midele)
    return binary_median

def save_debug_images_as_rgb(debug_dir, event_image, cumulative_image, B):
    os.makedirs(debug_dir, exist_ok=True)
    for b in range(B):
        event_channel = event_image[:, :, b]
        cumulative_channel = cumulative_image[:, :, b]
        event_channel_rgb = np.clip(event_channel * 255, 0, 255).astype(np.uint8)
        cumulative_channel_rgb = np.clip(cumulative_channel * 255, 0, 255).astype(np.uint8)
        cv2.imwrite(os.path.join(debug_dir, f"event_image_channel_{b}.png"), event_channel_rgb)
        cv2.imwrite(os.path.join(debug_dir, f"cumulative_image_channel_{b}.png"), cumulative_channel_rgb)

def create_event_image(args, x, y, p, t, shape, B=6, debug=False):
    height, width = shape[:2]
    event_image = np.zeros((height, width, B), dtype=np.float32)
    cumulative_image = np.zeros((height, width, B), dtype=np.float32)
    
    start_time, end_time = t[0], t[-1]
    delta_T = end_time - start_time
    normalized_timestamps = (B - 1) * (t - start_time) / delta_T
    
    x = np.clip(x.astype(int), 0, width - 1)
    y = np.clip(y.astype(int), 0, height - 1)

    bin_idx = np.round(normalized_timestamps).astype(int)
    bin_idx = np.clip(bin_idx, 0, B - 1)

    weights = np.maximum(0, 1 - np.abs(normalized_timestamps - bin_idx))
    np.add.at(event_image, (y, x, bin_idx), p * weights)

    norm_value_evs = np.maximum(np.abs(np.min(event_image)), np.max(event_image))
    event_image = (event_image + norm_value_evs)/ (2 * norm_value_evs)

    if args.event_filter == "great_filter":
        for i in range(0, B, 3):
            mask_group = []
            for j in range(i, min(i + 3, B)):
                mask = (bin_idx <= j) & (bin_idx >= np.maximum(0, j - 2))
                cumulative_image_f = cumulative_image[:, :, j].copy()
                np.add.at(cumulative_image_f, (y[mask], x[mask]), np.abs(p[mask]))
                motion_mask = mask_function(cumulative_image_f)
                mask_group.append(motion_mask)
            combined_mask = np.logical_or.reduce(mask_group)
            for j in range(i, min(i + 3, B)):
                cumulative_image[:, :, j] = combined_mask.astype(np.uint8)
        if debug:
            save_debug_images_as_rgb("./debug_output", event_image, cumulative_image, B)
        return event_image * cumulative_image
    else:
        return event_image

# ------------------- Dataset Class -------------------

class StableVideoDataset(Dataset):
    def __init__(self, args, video_data_dir, max_num_videos=None, frame_height=576, frame_width=1024, num_frames=14,
                 is_reverse_video=True, random_seed=42, skip_sampling_rate=1):
        
        self.video_data_dir = video_data_dir
        video_names = sorted([video for video in os.listdir(video_data_dir) 
                              if os.path.isdir(os.path.join(video_data_dir, video))])
        self.length = min(len(video_names), max_num_videos) if max_num_videos else len(video_names)
        self.video_names = video_names[:self.length]

        self.skip_sampling_rate = skip_sampling_rate
        if skip_sampling_rate < 1:
            self.skip_sampling_rate = random.choices([1, 2, 3], weights=[0.5, 0.3, 0.2], k=1)[0]

        self.sample_frames = num_frames * self.skip_sampling_rate + 1 - self.skip_sampling_rate
        self.sample_stride = self.skip_sampling_rate
        print("skip/sample_frames/stride:", self.skip_sampling_rate, self.sample_frames, self.sample_stride)

        self.event_nums_between = num_frames
        self.frame_width = frame_width
        self.frame_height = frame_height
        self.pixel_transforms = transforms.Compose([
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])
        self.is_reverse_video = is_reverse_video
        self.args = args
        np.random.seed(random_seed)

    def load_npz(self, file_path):
        data = np.load(file_path, allow_pickle=True)
        return data['x'], data['y'], data['polairty'], data['time']

    def accumulate_events(self, video_frame_paths, rgb_frame_paths, B=6):
        x_list, y_list, p_list, t_list = [], [], [], []
        video_frames = []
        shape = None

        for event_path, rgb_path in zip(video_frame_paths, rgb_frame_paths):
            x, y, p, t = self.load_npz(event_path)
            x_list.append(x); y_list.append(y); p_list.append(p); t_list.append(t)

            frame = np.array(Image.open(rgb_path)).astype(np.float32)/255.0
            video_frames.append(frame)
            shape = frame.shape[:2]

        x_all = np.concatenate(x_list)
        y_all = np.concatenate(y_list)
        p_all = np.concatenate(p_list)
        t_all = np.concatenate(t_list)

        event_voxel_bin = create_event_image(self.args, x_all, y_all, p_all, t_all, shape, B=B)
        pixel_values = torch.from_numpy(np.stack(video_frames, axis=0).transpose(0, 3, 1, 2))
        event_voxel_bin = torch.from_numpy(event_voxel_bin.transpose(2, 0, 1)).unsqueeze(0).repeat(len(video_frames), 1, 1, 1)

        return pixel_values, event_voxel_bin

    def get_batch(self, idx):
        video_name = self.video_names[idx]
        video_frame_paths = natsorted(glob(os.path.join(self.video_data_dir, video_name, 'events','*.npz')))
        rgb_frame_paths   = natsorted(glob(os.path.join(self.video_data_dir, video_name, 'images','*.png')))

        start_idx = np.random.randint(0, len(video_frame_paths)-self.sample_frames+1)
        video_frame_paths = video_frame_paths[start_idx:start_idx+self.sample_frames]
        rgb_frame_paths   = rgb_frame_paths[start_idx:start_idx+self.sample_frames]

        return self.accumulate_events(video_frame_paths, rgb_frame_paths, B=6)

    def crop_center_patch(self, pixel_values, event_voxel_bin, crop_h=512, crop_w=512, random_crop=False):
        H, W = pixel_values.shape[2:4]
        if random_crop:
            start_h = random.randint(0, H - crop_h)
            start_w = random.randint(0, W - crop_w)
        else:
            start_h = H//2 - crop_h//2
            start_w = W//2 - crop_w//2
        return (pixel_values[:, :, start_h:start_h+crop_h, start_w:start_w+crop_w],
                event_voxel_bin[:, :, start_h:start_h+crop_h, start_w:start_w+crop_w])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                pixel_values, event_voxel_bin = self.get_batch(idx)
                break
            except Exception as e:
                idx = random.randint(0, self.length-1)

        pixel_values = self.pixel_transforms(pixel_values)
        event_voxel_bin = event_voxel_bin  # events may not need normalization

        pixel_values, event_voxel_bin = self.crop_center_patch(pixel_values, event_voxel_bin, random_crop=True)
        conditions = pixel_values[-1]

        sample = dict(
            pixel_values=pixel_values,       # (T, 3, H, W)
            event_voxel_bin=event_voxel_bin, # (T, 6, H, W)
            conditions=conditions            # last RGB frame
        )
        return sample


In [None]:
import torch
from torch.utils.data import DataLoader
from types import SimpleNamespace  # Used to mock the 'args' object
import warnings
warnings.filterwarnings("ignore")

# --- Import your StableVideoDataset class here ---
# from your_dataset_file import StableVideoDataset

# 1. Create a mock 'args' object
mock_args = SimpleNamespace(event_filter="great_filter")

# 2. Define your data directory
data_directory = "/data/venkateswara_lab/frame_interpollation/bs_ergb/3_TRAINING/"

# 3. Create an instance of the dataset
print("Initializing dataset...")
try:
    dataset = StableVideoDataset(
        args=mock_args,
        video_data_dir=data_directory,
        frame_height=576,  # ✅ corrected name
        frame_width=1024,
        num_frames=5,
        skip_sampling_rate=1
    )
    print(f"✅ Dataset initialized successfully. Found {len(dataset)} videos in '{data_directory}'.")

    # --- Optional: Test loading the first item ---
    print("\n--- Testing __getitem__ ---")
    sample = dataset[0]
    print("✅ Successfully loaded sample 0.")
    print(f"  pixel_values shape: {sample['pixel_values'].shape}")      # (T, 3, H, W)
    print(f"  event_voxel_bin shape: {sample['event_voxel_bin'].shape}")  # (T, 6, H, W)

    # Display a few tensor values up to 4 decimals
    print("\n--- Sample Tensor Values ---")
    print("pixel_values[:,:,:4,:4]:\n", torch.round(sample['pixel_values'][0, :, :4, :4] * 10000) / 10000)
    print("event_voxel_bin[:,:,:4,:4]:\n", torch.round(sample['event_voxel_bin'][0, 0, :4, :4] * 10000) / 10000)

except FileNotFoundError:
    print(f"❌ Error: Data directory not found at '{data_directory}'.")
except Exception as e:
    print(f"❌ An error occurred during dataset initialization or testing:\n{e}")


Initializing dataset...
skip/sample_frames/stride: 1 50 1
✅ Dataset initialized successfully. Found 47 videos in '/data/venkateswara_lab/frame_interpollation/bs_ergb/3_TRAINING/'.

--- Testing __getitem__ ---


In [None]:
import os
import torch
import argparse
import copy
from diffusers.utils import load_image, export_to_video
from diffusers import UNetSpatioTemporalConditionModel
from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
from custom_diffusers.pipelines.evs_pipeline_frame_interpolation_with_noise_injection_color import EVSFrameInterpolationWithNoiseInjectionPipeline
from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from attn_ctrl.attention_control import (AttentionStore, 
                                         register_temporal_self_attention_control, 
                                         register_temporal_self_attention_flip_control,
)
from dataset.stable_video_dataset import StableVideoDataset,StableVideoTestDataset
from torch.utils.data import DataLoader
from einops import rearrange

import numpy as np
import cv2

from PIL import Image
import torch


def tensor_to_pillow(tensor, save_path):
    # Squeeze the batch dimension if exists and convert to numpy
    # print("val:",torch.max(tensor),torch.min(tensor))
    image_data = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()/2.0 + 0.5
    image_data = image_data /np.max(image_data) * 255
    image_data = image_data.astype("uint8")
    # Create a PIL image
    pil_image = Image.fromarray(image_data)
    # Save the image
    pil_image.save(save_path)
    return pil_image


def main(args):

    noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    # pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
    #     args.pretrained_model_name_or_path, 
    #     scheduler=noise_scheduler,
    #     variant="fp16",
    #     torch_dtype=torch.float16, 
    # )
    pipe = EVSFrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
        args.pretrained_model_name_or_path, 
        # args.checkpoint_dir,
        scheduler=noise_scheduler,
        variant="fp16",
        torch_dtype=torch.float16, 
    )
    
    
    # ref_unet = pipe.ori_unet
    
    # state_dict = pipe.unet.state_dict()
    # # computing delta w
    finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
        args.checkpoint_dir,
        subfolder="unet",
        torch_dtype=torch.float16,
    ) 
    # # assert finetuned_unet.config.num_frames==14
    # ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    #     "/mnt/workspace/zhangziran/DiffEVS/svd_keyframe_interpolation-main/checkpoints/stable-video-diffusion-img2vid",
    #     subfolder="unet",
    #     variant='fp16',
    #     torch_dtype=torch.float16,
    # )

    # # print("-----"*10)

    finetuned_state_dict = finetuned_unet.state_dict()
    # ori_state_dict = ori_unet.state_dict()
    # for name, param in finetuned_state_dict.items():
    #     if 'temporal_transformer_blocks.0.attn1.to_v' in name or "temporal_transformer_blocks.0.attn1.to_out.0" in name:
    #         delta_w = param - ori_state_dict[name]
    #         state_dict[name] = state_dict[name] + delta_w
    # pipe.unet.load_state_dict(state_dict)
    
    pipe.unet.load_state_dict(finetuned_state_dict)

    # # controller_ref= AttentionStore()
    # # register_temporal_self_attention_control(ref_unet, controller_ref)

    # # controller = AttentionStore()
    # # register_temporal_self_attention_flip_control(pipe.unet, controller, controller_ref)
    
    del finetuned_unet
    # del ori_unet

    pipe = pipe.to(args.device)


    print("-----"*10)
    
    # run inference
    generator = torch.Generator(device=args.device)
    if args.seed is not None:
        generator = generator.manual_seed(args.seed)
        
    dataset = StableVideoTestDataset(args,args.frames_dirs,num_frames=pipe.unet.config.num_frames,skip_sampling_rate=args.skip_sampling_rate)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
    print("-----"*10)
    for i, batch in enumerate(dataloader):
        print(i)
    
        evs = batch["event_voxel_bin"]
        frame2 = batch["conditions"]
        frame1 = batch["pixel_values"][:, 0]
        save_name = batch["save_name"]
        print(save_name)
        print(frame1.shape,frame2.shape,evs.shape)
        
        
        # frame11 = load_image(args.frame1_path)
        # # frame1 = frame1.resize((854, 640))
        # frame11 = frame11.resize((1024, 576))
        

        # frame22 = load_image(args.frame2_path)
        # frame22 = frame22.resize((1024, 576))
        # # # print("************"*20,frame11.shape,np.mean(frame11))
        # from torchvision import transforms
        # print("************" * 20, (tensor := transforms.ToTensor()(frame11)).shape, tensor.mean(),frame2.shape)
        
        # frame2 = frame2.resize((854, 640))
        # frame2 = frame2.resize((1024, 576))

        frames = pipe(image1=frame1, image2=frame2, evs=evs, height=frame1.shape[-2], width=frame1.shape[-1],
                    num_inference_steps=args.num_inference_steps, 
                    generator=generator,
                    weighted_average=args.weighted_average,
                    noise_injection_steps=args.noise_injection_steps,
                    noise_injection_ratio= args.noise_injection_ratio,
        ).frames[0]
        save_path = args.out_path
        save_path = save_path.replace("example",save_name[0])
        
        if save_path.endswith('.gif'):
            frames[0].save(save_path, save_all=True, append_images=frames[1:], duration=142, loop=0)
        else:
            export_to_video(frames, save_path, fps=7)
        
        from eval_function import calculate_metrics
        # 假设要保存txt结果
        txt_path = save_path.replace(".gif", "_metrics.txt")
        with open(txt_path, "w") as f:  # 使用'w'模式创建新文件
            f.write(f"Processing: {save_path}\n")
            
            # 创建用于存储各种指标值的字典
            metrics_values = {
                'lpips': [],
                'ssim': [],
                'psnr': [],
                'maniqa': [],
                'musiq': [],
                'liqe': []
            }

            # 遍历帧，排除初始帧和最后一帧进行评价
            for i, frame in enumerate(frames):
                if i == 0 or i == len(frames) - 1:  # 排除第一帧和最后一帧
                    frame.save(save_path.replace(".gif", f"_{i+1}s.png"))
                    frame_tensor = batch["pixel_values"][:, i]
                    frame_pillow = tensor_to_pillow(frame_tensor, save_path.replace(".gif", f"_{i+1}.png"))
                    print(save_path.replace(".gif", f"_{i+1}.png"), len(frames))
                    continue
                
                frame.save(save_path.replace(".gif", f"_{i+1}s.png"))
                frame_tensor = batch["pixel_values"][:, i]
                frame_pillow = tensor_to_pillow(frame_tensor, save_path.replace(".gif", f"_{i+1}.png"))
                print(save_path.replace(".gif", f"_{i+1}.png"), len(frames))
                
                # 获取frame和frame_tensor一一对应的图像并计算评价指标
                image1 = np.array(frame)  # frame 是 PIL 图像，转换为 numpy 数组
                image2 = np.array(frame_pillow) #frame_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()  # 转换 tensor 为 numpy 数组
                
                print(image1.shape, image2.shape, np.min(image1), np.max(image1), np.min(image2), np.max(image2))

                # 计算所有指标
                results = calculate_metrics(
                    image1, image2, 
                    loss_dict_lpips={'as_loss': False, 'weight': 1.0},
                    loss_dict_ssim={'weight': 1.0}, 
                    loss_dict_psnr={'as_loss': False, 'weight': 1.0},
                    loss_dict_maniqa={'as_loss': False, 'weight': 1.0},
                    loss_dict_musiq={'as_loss': False, 'weight': 1.0},
                    loss_dict_liqe={'as_loss': False, 'weight': 1.0}
                )

                # 将当前帧的结果追加到对应的列表中
                for metric_name, value in results.items():
                    if value is not None:
                        metrics_values[metric_name].append(value.cpu().numpy())

            # 计算平均指标
            avg_metrics = {}
            for metric_name, values in metrics_values.items():
                if values:  # 确保列表不为空
                    avg_metrics[metric_name] = np.mean(values)
                else:
                    avg_metrics[metric_name] = None

            # # 写入平均指标到txt
            # f.write("Average Metrics:\n")
            # for metric_name, avg_value in avg_metrics.items():
            #     if avg_value is not None:
            #         f.write(f"Average {metric_name.upper()}: {avg_value:.6f}\n")

            # 写入平均指标到txt（所有指标在一行中）
            f.write("Average Metrics: ")
            metric_strings = []
            for metric_name, avg_value in avg_metrics.items():
                if avg_value is not None:
                    metric_strings.append(f"{metric_name.upper()}: {avg_value:.6f}")
            f.write(", ".join(metric_strings) + "\n\n")  # 将所有指标用逗号分隔写在一行，并添加空行
            
            f.write("\n")  # 在每个视频的结果后加一个空行


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="stabilityai/stable-video-diffusion-img2vid-xt")
    parser.add_argument("--checkpoint_dir", type=str, required=True)
    # parser.add_argument('--frame1_path', type=str, required=True)
    # parser.add_argument('--frame2_path', type=str, required=True)
    parser.add_argument('--out_path', type=str, required=True)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_inference_steps', type=int, default=50)
    parser.add_argument('--weighted_average', action='store_true')
    parser.add_argument('--noise_injection_steps', type=int, default=0)
    parser.add_argument('--noise_injection_ratio', type=float, default=0.5)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--frames_dirs', type=str, default='cuda:0')
    parser.add_argument('--event_filter', type=str, default=None)
    parser.add_argument('--skip_sampling_rate', type=int, default=1)
    args = parser.parse_args()
    out_dir = os.path.dirname(args.out_path)
    os.makedirs(out_dir, exist_ok=True)
    main(args)

In [3]:

from data import StereoEventDataset
from torch.utils.data import DataLoader

class StereoEventTestDataset(StereoEventDataset):
    def __init__(self, video_data_dir, frame_height=375, frame_width=375):
        super().__init__(video_data_dir, frame_height, frame_width)
        self.video_names = [self.video_names[0]]
        self.length = 1  # Only one video for inference

    def __getitem__(self, idx):
        video_name = self.video_names[idx]
        paths = self._get_paths(video_name)

        left_rgb = self._load_rgb(paths['left']['rgb'])
        left_event = self._load_events(paths['left']['event'])
        right_rgb = self._load_rgb(paths['right']['rgb'])
        right_event = self._load_events(paths['right']['event'])

        # Apply transforms
        def apply_transform_to_sequence(sequence_tensor, transform_fn):
            if sequence_tensor.ndim == 3:
                return transform_fn(sequence_tensor)
            transformed_frames = [transform_fn(sequence_tensor[t]) for t in range(sequence_tensor.shape[0])]
            return torch.stack(transformed_frames, dim=0)

        left_rgb = apply_transform_to_sequence(left_rgb, self.transform_rgb)
        right_rgb = apply_transform_to_sequence(right_rgb, self.transform_rgb)
        left_event = apply_transform_to_sequence(left_event, self.transforms_evs)
        right_event = apply_transform_to_sequence(right_event, self.transforms_evs)

        # Use center crop (no random cropping for inference)
        left_pixel_values, left_events = self.crop_center_patch(left_rgb, left_event, random_crop=False)
        right_pixel_values, right_events = self.crop_center_patch(right_rgb, right_event, random_crop=False)

        return dict(
            left=dict(pixel_values=left_pixel_values, events=left_events),
            right=dict(pixel_values=right_pixel_values, events=right_events),
            video_name=video_name)
    
dataset = StereoEventTestDataset("/data/venkateswara_lab/frame_interpollation/data")
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
print("-----"*10)
# for i, batch in enumerate(dataloader):
#     print(i)

#     evs = batch["event_voxel_bin"]
#     frame2 = batch["conditions"][-1]
#     frame1 = batch["pixel_values"][:, 0]
#     print(frame1.shape,frame2.shape,evs.shape)


--------------------------------------------------


In [1]:
from data import StereoEventDataset
from torch.utils.data import DataLoader
import torch
import torchvision.transforms as T
import cv2
import numpy as np

class StereoEventTestDataset(StereoEventDataset):
    def __init__(self, video_data_dir, frame_height=375, frame_width=375):
        super().__init__(video_data_dir, frame_height, frame_width)
        self.video_names = [self.video_names[0]]  # take only first video
        self.length = 1
        frame_height = 576
        frame_width = 1024

    def __getitem__(self, idx):
        video_name = self.video_names[idx]
        paths = self._get_paths(video_name)
        resize_transform = T.Compose([
    T.Resize((frame_height, frame_width)),  # height, width
])

        # Load left and right data
        left_rgb = self._load_rgb(paths['left']['rgb'])
        left_event = self._load_events(paths['left']['event'])
        right_rgb = self._load_rgb(paths['right']['rgb'])
        right_event = self._load_events(paths['right']['event'])

        # Helper for frame-wise transform
        def apply_transform_to_sequence(sequence_tensor, transform_fn):
            if sequence_tensor.ndim == 3:
                return transform_fn(sequence_tensor)
            transformed_frames = [transform_fn(sequence_tensor[t]) for t in range(sequence_tensor.shape[0])]
            return torch.stack(transformed_frames, dim=0)

        left_rgb = apply_transform_to_sequence(left_rgb, self.transform_rgb)
        right_rgb = apply_transform_to_sequence(right_rgb, self.transform_rgb)
        left_event = apply_transform_to_sequence(left_event, self.transforms_evs)
        right_event = apply_transform_to_sequence(right_event, self.transforms_evs)

        # Center crop for inference
        left_pixel_values, left_events = self.crop_center_patch(left_rgb, left_event, random_crop=False)
        right_pixel_values, right_events = self.crop_center_patch(right_rgb, right_event, random_crop=False)

        # Select specific frames/events
        frame1 = left_pixel_values[0]             # first frame of left RGB
        frame2 = left_pixel_values[-1]            # last frame of left RGB
        evs = right_events                        # right events

        frame1_resized = resize_transform(frame1.squeeze(0))
        frame2_resized = resize_transform(frame2.squeeze(0))
        return dict(
            frame1=frame1_resized,
            frame2=frame2_resized,
            evs=evs,
            video_name=video_name
        )


frame_height = 576
frame_width = 1024

# Define resize transform
resize_transform = T.Compose([
    T.Resize((frame_height, frame_width)),  # height, width
])

dataset = StereoEventTestDataset("/data/venkateswara_lab/frame_interpollation/data")
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

print("-----" * 10)

for i, batch in enumerate(dataloader):
    frame1 = batch['frame1']  # Shape: [B, C, H, W]
    frame2 = batch['frame2']  # Shape: [B, C, H, W]
    evs = batch['evs']
    video_name = batch['video_name']
    
    print(f"Batch {i}:")
    print(f"  Video name: {video_name}")
    print(f"  Frame1 shape: {frame1.shape}")
    print(f"  Frame2 shape: {frame2.shape}")
    print(f"  Events shape: {evs.shape}")

--------------------------------------------------
Batch 0:
  Video name: ['stereo_vkitti200000000']
  Frame1 shape: torch.Size([1, 3, 576, 1024])
  Frame2 shape: torch.Size([1, 3, 576, 1024])
  Events shape: torch.Size([1, 17, 6, 375, 375])


In [7]:
import os
import torch
import argparse
import copy
from diffusers.utils import load_image, export_to_video
from diffusers import UNetSpatioTemporalConditionModel
from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
from custom_diffusers.pipelines.evs_pipeline_frame_interpolation_with_noise_injection_color import EVSFrameInterpolationWithNoiseInjectionPipeline
from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from attn_ctrl.attention_control import (AttentionStore, 
                                         register_temporal_self_attention_control, 
                                         register_temporal_self_attention_flip_control)

from torch.utils.data import DataLoader
from einops import rearrange

import numpy as np
import cv2

from PIL import Image
import torch


def tensor_to_pillow(tensor, save_path):
    image_data = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()/2.0 + 0.5
    image_data = image_data /np.max(image_data) * 255
    image_data = image_data.astype("uint8")
    pil_image = Image.fromarray(image_data)
    pil_image.save(save_path)
    return pil_image

def main(args):

    noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    pipe = EVSFrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
        args.pretrained_model_name_or_path, 
        # args.checkpoint_dir,
        scheduler=noise_scheduler,
        variant="fp16",
        torch_dtype=torch.float16, 
    )
    
    
    # ref_unet = pipe.ori_unet
    
    # state_dict = pipe.unet.state_dict()
    # # computing delta w
    finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
        args.checkpoint_dir,
        subfolder="unet",
        torch_dtype=torch.float16,
    ) 
    # # assert finetuned_unet.config.num_frames==14
    # ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    #     "/mnt/workspace/zhangziran/DiffEVS/svd_keyframe_interpolation-main/checkpoints/stable-video-diffusion-img2vid",
    #     subfolder="unet",
    #     variant='fp16',
    #     torch_dtype=torch.float16,
    # )

    # # print("-----"*10)

    finetuned_state_dict = finetuned_unet.state_dict()
    # ori_state_dict = ori_unet.state_dict()
    # for name, param in finetuned_state_dict.items():
    #     if 'temporal_transformer_blocks.0.attn1.to_v' in name or "temporal_transformer_blocks.0.attn1.to_out.0" in name:
    #         delta_w = param - ori_state_dict[name]
    #         state_dict[name] = state_dict[name] + delta_w
    # pipe.unet.load_state_dict(state_dict)
    
    pipe.unet.load_state_dict(finetuned_state_dict)

    # # controller_ref= AttentionStore()
    # # register_temporal_self_attention_control(ref_unet, controller_ref)

    # # controller = AttentionStore()
    # # register_temporal_self_attention_flip_control(pipe.unet, controller, controller_ref)
    
    del finetuned_unet
    # del ori_unet

    pipe = pipe.to(args.device)
    print("-----"*10)
    generator = torch.Generator(device=args.device)
    if args.seed is not None:
        generator = generator.manual_seed(args.seed)
        

    dataset = StereoEventTestDataset("/data/venkateswara_lab/frame_interpollation/data")
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
    i=1
    print("-----" * 10)
    for i, batch in enumerate(dataloader):
        print(f"Video {i}: {batch['video_name']}")
        print(type(batch["frame1"]))
        print("Frame1 shape:", batch["frame1"].shape)
        print("Frame2 shape:", batch["frame2"].shape)
        print("Events shape:", batch["evs"].shape)
        
        frames = pipe(image1=batch["frame1"], image2=batch["frame1"], evs=batch["evs"], height=batch["frame1"].shape[-2], width=batch["frame1"].shape[-1],
                    num_inference_steps=args.num_inference_steps, 
                    generator=generator,
                    weighted_average=args.weighted_average,
                    noise_injection_steps=args.noise_injection_steps,
                    noise_injection_ratio= args.noise_injection_ratio,
        ).frames[0]
        save_path = args.out_path

        export_to_video(frames, save_path, fps=7)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="stabilityai/stable-video-diffusion-img2vid")
    parser.add_argument("--checkpoint_dir", type=str, default ='/data/venkateswara_lab/frame_interpollation/trained_models_full_recent/checkpoint-9500')
    # parser.add_argument('--frame1_path', type=str, required=True)
    # parser.add_argument('--frame2_path', type=str, required=True)
    parser.add_argument('--out_path', type=str, default="/data/venkateswara_lab/frame_interpollation/code/examples")
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_inference_steps', type=int, default=50)
    parser.add_argument('--weighted_average', action='store_true')
    parser.add_argument('--noise_injection_steps', type=int, default=0)
    parser.add_argument('--noise_injection_ratio', type=float, default=0.5)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--frames_dirs', type=str, default='cuda:0')
    parser.add_argument('--event_filter', type=str, default=None)
    parser.add_argument('--skip_sampling_rate', type=int, default=1)
    args = parser.parse_args()
    out_dir = os.path.dirname(args.out_path)
    os.makedirs(out_dir, exist_ok=True)
    main(args)

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.


Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00,  9.42it/s]
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.


--------------------------------------------------
--------------------------------------------------
Video 0: ['stereo_vkitti200000000']
<class 'torch.Tensor'>
Frame1 shape: torch.Size([1, 3, 576, 1024])
Frame2 shape: torch.Size([1, 3, 576, 1024])
Events shape: torch.Size([1, 17, 6, 375, 375])


TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not AutoencoderKLOutput

In [None]:
import os
import torch
import argparse
import copy
from diffusers.utils import load_image, export_to_video
from diffusers import UNetSpatioTemporalConditionModel
from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
from custom_diffusers.pipelines.evs_pipeline_frame_interpolation_with_noise_injection_color import EVSFrameInterpolationWithNoiseInjectionPipeline
from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from attn_ctrl.attention_control import (AttentionStore, 
                                         register_temporal_self_attention_control, 
                                         register_temporal_self_attention_flip_control)

from torch.utils.data import DataLoader, Dataset
from einops import rearrange
from accelerate import Accelerator, DistributedDataParallelKwargs

import numpy as np
import cv2
from PIL import Image
import torch
import glob


class StereoEventTestDataset(Dataset):
    """Dataset class for stereo event data"""
    def __init__(self, data_dir):
        self.data_dir = data_dir
        # Find all video directories
        self.video_dirs = sorted(glob.glob(os.path.join(data_dir, "*")))
        
    def __len__(self):
        return len(self.video_dirs)
    
    def __getitem__(self, idx):
        video_dir = self.video_dirs[idx]
        video_name = os.path.basename(video_dir)
        
        # Load frame1
        frame1_path = os.path.join(video_dir, "frame1.png")
        frame1 = Image.open(frame1_path).convert("RGB")
        frame1 = torch.from_numpy(np.array(frame1)).permute(2, 0, 1).float() / 255.0
        frame1 = frame1 * 2.0 - 1.0  # Normalize to [-1, 1]
        frame1 = frame1.unsqueeze(0)  # Add batch dimension
        
        # Load frame2
        frame2_path = os.path.join(video_dir, "frame2.png")
        frame2 = Image.open(frame2_path).convert("RGB")
        frame2 = torch.from_numpy(np.array(frame2)).permute(2, 0, 1).float() / 255.0
        frame2 = frame2 * 2.0 - 1.0  # Normalize to [-1, 1]
        frame2 = frame2.unsqueeze(0)  # Add batch dimension
        
        # Load events
        evs_path = os.path.join(video_dir, "events.npy")
        evs = np.load(evs_path)
        evs = torch.from_numpy(evs).float()
        if evs.dim() == 3:
            evs = evs.unsqueeze(0)  # Add batch dimension if needed
        
        return {
            "frame1": frame1,
            "frame2": frame2,
            "evs": evs,
            "video_name": video_name
        }


def tensor_to_pillow(tensor, save_path):
    """Convert tensor to PIL image and save"""
    image_data = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() / 2.0 + 0.5
    image_data = image_data / np.max(image_data) * 255
    image_data = image_data.astype("uint8")
    pil_image = Image.fromarray(image_data)
    pil_image.save(save_path)
    return pil_image


def main(args):
    # Initialize Accelerator for multi-GPU
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        mixed_precision="fp16",
        kwargs_handlers=[ddp_kwargs]
    )
    
    # Only print on main process
    if accelerator.is_main_process:
        print(f"Using {accelerator.num_processes} GPUs for inference")
        print(f"Process index: {accelerator.process_index}")
        print("=" * 50)

    # Load scheduler
    noise_scheduler = EulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, 
        subfolder="scheduler"
    )
    
    # Load pipeline
    if accelerator.is_main_process:
        print("Loading base pipeline...")
    
    pipe = EVSFrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
        args.pretrained_model_name_or_path, 
        scheduler=noise_scheduler,
        variant="fp16",
        torch_dtype=torch.float16, 
    )
    
    # Load finetuned UNet
    if accelerator.is_main_process:
        print(f"Loading finetuned UNet from {args.checkpoint_dir}...")
    
    finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
        args.checkpoint_dir,
        subfolder="unet",
        torch_dtype=torch.float16,
    ) 

    finetuned_state_dict = finetuned_unet.state_dict()
    pipe.unet.load_state_dict(finetuned_state_dict)
    
    del finetuned_unet
    torch.cuda.empty_cache()

    if accelerator.is_main_process:
        print("UNet loaded successfully")

    # Move pipeline components to device
    pipe = pipe.to(accelerator.device)
    
    # Prepare UNet with Accelerator for distributed inference
    pipe.unet = accelerator.prepare(pipe.unet)
    
    if accelerator.is_main_process:
        print("Model prepared for distributed inference")
        print("=" * 50)
    
    # Setup generator
    generator = torch.Generator(device=accelerator.device)
    if args.seed is not None:
        # Different seed for each GPU to add diversity
        generator = generator.manual_seed(args.seed + accelerator.process_index)
        
    # Load dataset
    if accelerator.is_main_process:
        print("Loading dataset...")
    
    dataset = StereoEventTestDataset(args.data_dir)
    
    if accelerator.is_main_process:
        print(f"Found {len(dataset)} videos")
    
    # Create dataloader
    dataloader = DataLoader(
        dataset, 
        batch_size=1, 
        shuffle=False, 
        num_workers=4,
        pin_memory=True
    )
    
    # Prepare dataloader with Accelerator (splits data across GPUs)
    dataloader = accelerator.prepare(dataloader)
    
    if accelerator.is_main_process:
        print(f"Processing {len(dataset)} samples across {accelerator.num_processes} GPUs")
        print(f"Each GPU will process approximately {len(dataset) // accelerator.num_processes} samples")
        print("=" * 50)
    
    # Process videos
    for i, batch in enumerate(dataloader):
        # Calculate global index
        global_idx = i * accelerator.num_processes + accelerator.process_index
        
        video_name = batch['video_name'][0] if isinstance(batch['video_name'], (list, tuple)) else batch['video_name']
        
        if accelerator.is_main_process or True:  # Print from all processes
            print(f"[GPU {accelerator.process_index}] Processing batch {i} - Video: {video_name}")
            print(f"[GPU {accelerator.process_index}] Frame1 shape: {batch['frame1'].shape}")
            print(f"[GPU {accelerator.process_index}] Frame2 shape: {batch['frame2'].shape}")
            print(f"[GPU {accelerator.process_index}] Events shape: {batch['evs'].shape}")
        
        # Move batch to device if not already there
        frame1 = batch["frame1"].to(accelerator.device)
        frame2 = batch["frame2"].to(accelerator.device)
        evs = batch["evs"].to(accelerator.device)
        
        # Run inference
        with torch.no_grad():
            try:
                output = pipe(
                    image1=frame1, 
                    image2=frame2,  # You had frame1 twice in original, assuming this should be frame2
                    evs=evs, 
                    height=frame1.shape[-2], 
                    width=frame1.shape[-1],
                    num_inference_steps=args.num_inference_steps, 
                    generator=generator,
                    weighted_average=args.weighted_average,
                    noise_injection_steps=args.noise_injection_steps,
                    noise_injection_ratio=args.noise_injection_ratio,
                )
                frames = output.frames[0]
                
                # Save output
                os.makedirs(args.out_path, exist_ok=True)
                save_path = os.path.join(
                    args.out_path, 
                    f"{video_name}_gpu{accelerator.process_index}.mp4"
                )
                
                export_to_video(frames, save_path, fps=args.fps)
                
                print(f"[GPU {accelerator.process_index}] Saved video to: {save_path}")
                
            except Exception as e:
                print(f"[GPU {accelerator.process_index}] Error processing {video_name}: {str(e)}")
                import traceback
                traceback.print_exc()
    
    # Wait for all processes to finish
    accelerator.wait_for_everyone()
    
    if accelerator.is_main_process:
        print("=" * 50)
        print("All videos processed successfully!")
        print(f"Output saved to: {args.out_path}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Multi-GPU Frame Interpolation Inference")
    
    # Model arguments
    parser.add_argument(
        "--pretrained_model_name_or_path", 
        type=str, 
        default="stabilityai/stable-video-diffusion-img2vid",
        help="Path to pretrained model"
    )
    parser.add_argument(
        "--checkpoint_dir", 
        type=str, 
        default='/data/venkateswara_lab/frame_interpollation/trained_models_full_recent/checkpoint-9500',
        help="Path to finetuned checkpoint"
    )
    
    # Data arguments
    parser.add_argument(
        '--data_dir',
        type=str,
        default="/data/venkateswara_lab/frame_interpollation/data",
        help="Directory containing input data"
    )
    parser.add_argument(
        '--out_path', 
        type=str, 
        default="/data/venkateswara_lab/frame_interpollation/code/examples",
        help="Output directory for generated videos"
    )
    
    # Inference arguments
    parser.add_argument(
        '--seed', 
        type=int, 
        default=42,
        help="Random seed"
    )
    parser.add_argument(
        '--num_inference_steps', 
        type=int, 
        default=50,
        help="Number of denoising steps"
    )
    parser.add_argument(
        '--fps',
        type=int,
        default=7,
        help="Output video FPS"
    )
    
    # Noise injection arguments
    parser.add_argument(
        '--weighted_average', 
        action='store_true',
        help="Use weighted average"
    )
    parser.add_argument(
        '--noise_injection_steps', 
        type=int, 
        default=0,
        help="Number of noise injection steps"
    )
    parser.add_argument(
        '--noise_injection_ratio', 
        type=float, 
        default=0.5,
        help="Noise injection ratio"
    )
    
    # Additional arguments
    parser.add_argument(
        '--event_filter', 
        type=str, 
        default=None,
        help="Event filter type"
    )
    parser.add_argument(
        '--skip_sampling_rate', 
        type=int, 
        default=1,
        help="Skip sampling rate"
    )
    
    args = parser.parse_args()
    
    # Create output directory
    os.makedirs(args.out_path, exist_ok=True)
    
    # Run main
    main(args)