In [1]:
import debugpy
import gc
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from typing import Optional, Tuple, Union
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from diffusers.models.unets import UNetSpatioTemporalConditionModel
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
from diffusers.models.unets import UNetSpatioTemporalConditionModel
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion_with_controlnet import StableVideoDiffusionPipelineWithControlNet,SpatioTemporalControlNet, CustomConditioningNet, SpatioTemporalControlNetOutput

from diffusers.image_processor import VaeImageProcessor

from types import MethodType

# Runtime configuration shared by the rest of the notebook.
dtype = torch.float16
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Scalar tensor holding 1, cast to the working dtype.
# NOTE(review): this name shadows the stdlib `time` module if that is ever imported.
time = torch.tensor(1).to(dtype=dtype)

# Logger obtained from the diffusers logging helper, named after this module.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
from typing import Optional, List

from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion_with_controlnet import StableVideoDiffusionPipelineWithControlNet,SpatioTemporalControlNet, CustomConditioningNet, SpatioTemporalControlNetOutput

import gc
from diffusers import DiffusionPipeline
# Ask PyTorch's CUDA caching allocator to release memory it is holding but not
# using; mainly relevant when this cell is re-run in an already-used kernel.
torch.cuda.empty_cache()


import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import json


class DiffusionDataset(Dataset):
    """Dataset of multi-image samples described by a JSON index file.

    The JSON file is read once at construction time. The keys used by this
    class are ``'ground_truth'``, ``'conditioning_images_one'`` and
    ``'caption'``; each is indexed first by sample and then by position inside
    that sample.

    Args:
        json_path: Path of the JSON index file.
        height: Target image height in pixels. Defaults to 320.
        width: Target image width in pixels. Defaults to 512.
    """

    def __init__(self, json_path, height=320, width=512):
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        self.height = height
        self.width = width
        self.transform = transforms.Compose([
            transforms.Resize((height, width)),
            transforms.CenterCrop((height, width)),
        ])
        self.image_processor = VaeImageProcessor(vae_scale_factor=8)

    def __len__(self):
        """Return the number of samples, i.e. the number of ground-truth image sets."""
        return len(self.data['ground_truth'])

    def _load_image(self, path, mode=None):
        """Open ``path``, optionally convert it to ``mode``, and apply ``self.transform``.

        The file is opened inside a ``with`` block so the underlying file
        handle is closed deterministically instead of lingering until garbage
        collection. The transformed image is produced before the block exits.
        """
        with Image.open(path) as img:
            if mode is not None:
                img = img.convert(mode)
            return self.transform(img)

    def __getitem__(self, idx):
        """Load and preprocess the sample at position ``idx``.

        Returns:
            dict with keys:
                ``"ground_truth"``: result of ``VaeImageProcessor.preprocess``
                    on the sample's ground-truth images.
                ``"conditioning"``: result of ``VaeImageProcessor.preprocess``
                    on the sample's first conditioning set (converted to RGB).
                ``"caption"``: first caption entry of the sample.
                ``"reference_image"``: the first ground-truth image after
                    ``self.transform`` only (not run through the processor).
        """
        # Ground-truth images are used in whatever mode they are stored in.
        ground_truth_images = [self._load_image(path) for path in self.data['ground_truth'][idx]]
        ground_truth_images = self.image_processor.preprocess(
            image=ground_truth_images, height=self.height, width=self.width
        )

        # Conditioning set one is explicitly converted to 3-channel RGB.
        conditioning_images_one = [
            self._load_image(path, mode="RGB")
            for path in self.data['conditioning_images_one'][idx]
        ]
        conditioning_images_one = self.image_processor.preprocess(
            image=conditioning_images_one, height=self.height, width=self.width
        )

        # A second conditioning set ('conditioning_images_two', to be concatenated
        # with set one along the channel dimension) is currently not loaded.

        # The reference image is the first ground-truth image of the sample.
        reference_image = self._load_image(self.data['ground_truth'][idx][0])

        # First caption entry of the sample.
        caption = self.data['caption'][idx][0]

        return {
            "ground_truth": ground_truth_images,
            "conditioning": conditioning_images_one,
            "caption": caption,
            "reference_image": reference_image
        }

def collate_fn(batch):
    """Merge a list of dataset items into a single batch dictionary.

    The ``'ground_truth'`` and ``'conditioning'`` entries of all items are
    stacked along a new leading dimension, and that dimension is then merged
    with the following one. Only the caption and reference image of the
    *first* item are kept, so the other items' captions and reference images
    are dropped when the batch holds more than one item.

    Args:
        batch: list of dicts with keys ``'ground_truth'``, ``'conditioning'``,
            ``'caption'`` and ``'reference_image'``.

    Returns:
        dict with the same four keys.
    """
    stacked_ground_truth = torch.stack([sample['ground_truth'] for sample in batch])
    stacked_conditioning = torch.stack([sample['conditioning'] for sample in batch])
    # Captions are plain strings and reference images are kept as-is: no stacking.
    all_captions = [sample['caption'] for sample in batch]
    all_references = [sample['reference_image'] for sample in batch]

    collated = {}
    collated["ground_truth"] = torch.flatten(stacked_ground_truth, start_dim=0, end_dim=1)
    collated["conditioning"] = torch.flatten(stacked_conditioning, start_dim=0, end_dim=1)
    collated["caption"] = all_captions[0]
    collated["reference_image"] = all_references[0]
    return collated



# --- Data -------------------------------------------------------------------
# NOTE(review): absolute, machine-specific path; consider making it configurable.
train_dataset = DiffusionDataset(json_path='/home/wisley/custom_diffusers_library/src/diffusers/jasper/validation.json')

train_dataloader = DataLoader(
    train_dataset,
    shuffle=False,
    collate_fn=collate_fn,
    batch_size=1,  # collate_fn keeps only the first caption / reference image of a batch
    num_workers=0,  # Adjust based on your setup
)
import matplotlib.pyplot as plt
# Take the first sample of the dataset as the inference input.
data_iter = iter(train_dataloader)
batch = next(data_iter)
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion_with_controlnet import StableVideoDiffusionPipelineWithControlNet,SpatioTemporalControlNet, CustomConditioningNet, SpatioTemporalControlNetOutput
prompt = "A driving scene during the night, with rainy weather in boston-seaport"
# prompt = batch['caption']
# First 14 conditioning entries of the batch.
pseudo_sample = batch['conditioning'][:14]
# Seeded generator so that sampling is reproducible.
generator = torch.Generator().manual_seed(42)
image = batch['reference_image']

# --- Models -----------------------------------------------------------------
SVD_MODEL_ID = "stabilityai/stable-video-diffusion-img2vid"
cuda_device = torch.device("cuda")

# `ignore_mismatched_sizes` is intentionally not passed: the pipeline reports it
# as an unexpected keyword argument and ignores it.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    SVD_MODEL_ID, torch_dtype=torch.float16, variant="fp16"
)
# Loaded only to borrow its tokenizer and text encoder below.
pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")


# Getting the models
tokenizer = pipeline.tokenizer
text_encoder = pipeline.text_encoder
# Copy the pretrained UNet weights into a freshly constructed UNet, then derive
# a ControlNet from that UNet and cast both to half precision.
unet_weights = pipe.unet.state_dict()
testing_unet = UNetSpatioTemporalConditionModel()
testing_unet.load_state_dict(unet_weights)
control_net_blank = SpatioTemporalControlNet.from_unet(testing_unet).half()
testing_unet = testing_unet.half()

# Second copy of the pretrained pipeline; only its UNet is used below.
pipe_zero = StableVideoDiffusionPipeline.from_pretrained(
    SVD_MODEL_ID, torch_dtype=torch.float16, variant="fp16"
)

# Assemble the ControlNet-enabled pipeline from the pieces loaded above.
pipe_zero_with_controlnet = StableVideoDiffusionPipelineWithControlNet(
    vae = pipe.vae,
    image_encoder = pipe.image_encoder,
    unet=pipe_zero.unet.to(dtype=torch.float16, device=cuda_device),
    scheduler=pipe.scheduler,
    feature_extractor=pipe.feature_extractor,
    controlnet=control_net_blank,
    tokenizer = tokenizer,
    text_encoder = text_encoder
)

# Move every component of the assembled pipeline to the GPU in fp16.
pipe_zero_with_controlnet = pipe_zero_with_controlnet.to(dtype=torch.float16, device=cuda_device)

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(
Keyword arguments {'ignore_mismatched_sizes': False} are not expected by StableVideoDiffusionPipeline and will be ignored.
Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00,  8.20it/s]
  torch.utils._pytree._register_pytree_node(
Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 12.87it/s]
Keyword arguments {'ignore_mismatched_sizes': False} are not expected by StableVideoDiffusionPipeline and will be ignored.
Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00,  6.10it/s]


VAE scale factor: 8


In [2]:
# Rewind the shared generator to the seed it was created with, so re-running
# this cell reproduces the same video instead of continuing the random stream.
generator.manual_seed(generator.initial_seed())

# Generate 14 frames at 320x512 from the reference image, the text prompt and
# the conditioning images, then keep the frames of the first (only) video.
frames_zero_with_controlnet = pipe_zero_with_controlnet(
    height=320,
    width=512,
    image=image,
    num_frames=14,
    prompt=prompt,
    conditioning_image=pseudo_sample,
    decode_chunk_size=8,
    generator=generator,
).frames[0]
# Write the frames to disk; the returned path is displayed as the cell output.
export_to_video(frames_zero_with_controlnet, "pod.mp4", fps=7)


  return F.conv3d(
  return F.conv2d(
100%|██████████| 25/25 [00:13<00:00,  1.92it/s]


'pod.mp4'