In [1]:
import debugpy
import gc
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from typing import Optional, Tuple, Union
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from diffusers.models.unets import UNetSpatioTemporalConditionModel
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
from diffusers.models.unets import UNetSpatioTemporalConditionModel
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion_with_controlnet import StableVideoDiffusionPipelineWithControlNet,SpatioTemporalControlNet, CustomConditioningNet, SpatioTemporalControlNetOutput
from diffusers.image_processor import VaeImageProcessor

from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion_with_controlnet import wrapperModel,StableVideoDiffusionPipelineWithWrapper, StableVideoDiffusionPipelineWithControlNet,StableVideoDiffusionPipelineWithControlNet, SpatioTemporalControlNet, CustomConditioningNet, SpatioTemporalControlNetOutput

import gc
from diffusers import DiffusionPipeline

from types import MethodType
# Ask PyTorch's caching allocator to release unused cached GPU memory; this runs
# before the dataset or any model below is created.
torch.cuda.empty_cache()


import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import json


class DiffusionDataset(Dataset):
    """Per-scene dataset of frame sequences for conditioned video diffusion.

    The JSON file at ``json_path`` must contain a mapping with the keys
    ``ground_truth``, ``conditioning_images_one``, ``conditioning_images_two``,
    ``reference_image`` and ``caption``. Every value is indexed by sample:
    the first three give a list of image paths per sample, while for
    ``reference_image`` and ``caption`` only element ``[0]`` of each sample's
    entry is used.

    Args:
        json_path: Path to the JSON index file.
        height: Target frame height in pixels. Defaults to 320.
        width: Target frame width in pixels. Defaults to 512.
    """

    def __init__(self, json_path, height=320, width=512):
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        self.height = height
        self.width = width
        frame_size = (height, width)
        self.transform = transforms.Compose([
            transforms.Resize(frame_size),
            # Resize already produces exactly `frame_size`; the crop is kept so the
            # transform pipeline stays the same as before.
            transforms.CenterCrop(frame_size),
        ])
        self.image_processor = VaeImageProcessor(vae_scale_factor=8)

    def __len__(self):
        """Return the number of samples, i.e. the number of ground-truth sequences."""
        return len(self.data['ground_truth'])

    def _load_frames(self, paths):
        """Open every path, apply the resize/crop transform and run the VAE preprocessing.

        Returns whatever ``VaeImageProcessor.preprocess`` returns for the list of
        transformed images (one entry per path, frames along the first axis).
        """
        frames = [self.transform(Image.open(path)) for path in paths]
        return self.image_processor.preprocess(image=frames, height=self.height, width=self.width)

    def __getitem__(self, idx):
        """Build one sample.

        Returns:
            dict with
              * ``ground_truth``: preprocessed ground-truth frames,
              * ``conditioning``: per-frame channel-wise concatenation of the two
                conditioning sets, stacked along a new leading frame axis,
              * ``caption``: the sample's caption,
              * ``reference_image``: the resized reference image (not run through
                the VAE preprocessing).
        """
        ground_truth_images = self._load_frames(self.data['ground_truth'][idx])
        conditioning_images_one = self._load_frames(self.data['conditioning_images_one'][idx])
        conditioning_images_two = self._load_frames(self.data['conditioning_images_two'][idx])

        # Iterating a preprocessed batch yields one frame at a time, so dim=0 of
        # each frame is its channel axis. No mode conversion happens here: the
        # channel count of each set is whatever the image files provide (the
        # recorded run shows 4 channels in total after concatenation).
        conditioned_images = [
            torch.cat((frame_one, frame_two), dim=0)
            for frame_one, frame_two in zip(conditioning_images_one, conditioning_images_two)
        ]

        # A single reference image per sample: only the first listed path is used.
        reference_image = self.transform(Image.open(self.data['reference_image'][idx][0]))

        # Likewise only the first caption entry is used.
        caption = self.data['caption'][idx][0]

        return {
            "ground_truth": ground_truth_images,
            "conditioning": torch.stack(conditioned_images),
            "caption": caption,
            "reference_image": reference_image
        }

def collate_fn(batch):
    """Collate exactly one dataset sample into the inputs used downstream.

    The returned dict carries a single caption and a single reference image, so
    a batch with more than one sample cannot be represented: its frames would be
    merged along the first axis while every caption/reference image except the
    first would be dropped. Such batches are rejected instead.

    Args:
        batch: List of sample dicts with the keys ``ground_truth``,
            ``conditioning``, ``caption`` and ``reference_image``.

    Returns:
        dict with ``ground_truth`` and ``conditioning`` (the sample's tensors with
        the batch axis folded into the first axis), ``caption`` and
        ``reference_image``.

    Raises:
        ValueError: If ``batch`` does not contain exactly one sample.
    """
    if len(batch) != 1:
        raise ValueError(
            "collate_fn supports batch_size=1 only: it returns a single caption and "
            f"reference image, but received {len(batch)} samples"
        )

    ground_truth = torch.stack([item['ground_truth'] for item in batch])
    conditioning = torch.stack([item['conditioning'] for item in batch])
    sample = batch[0]

    return {
        # flatten(0, 1) merges the (size-1) batch axis into the following axis.
        "ground_truth": ground_truth.flatten(0, 1),
        "conditioning": conditioning.flatten(0, 1),
        "caption": sample['caption'],
        "reference_image": sample['reference_image'],
    }



# JSON index consumed by DiffusionDataset (image paths and captions per sample).
# NOTE: machine-specific absolute path; adjust when running elsewhere.
TRAIN_JSON_PATH = '/home/wisley/custom_diffusers_library/src/diffusers/jasper/complete_data_paths.json'
# Number of batches printed by the sanity-check loop below.
NUM_PREVIEW_BATCHES = 3

train_dataset = DiffusionDataset(json_path=TRAIN_JSON_PATH)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=False,
    collate_fn=collate_fn,
    # collate_fn returns one caption and one reference image per batch, so keep this at 1.
    batch_size=1,
    num_workers=0,  # Adjust based on your setup
)

# Built once rather than on every iteration; only used to report the reference
# image's tensor shape.
to_tensor = transforms.ToTensor()

# Sanity check: print a summary of the first few batches.
# `batch` is deliberately left bound to the last previewed batch; the inference
# cell further down reads batch['conditioning'] and batch['reference_image'].
for i, batch in enumerate(train_dataloader):
    print(f"Batch {i} has {batch['ground_truth'].shape[0]} samples")
    print(f"Caption: {batch['caption']}")
    print(f"Reference image shape: {to_tensor(batch['reference_image']).shape}")
    print(f"Conditioning image shape: {batch['conditioning'].shape}")

    if i == NUM_PREVIEW_BATCHES - 1:
        break




  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


Batch 0 has 14 samples
Caption: A driving scene during the day, with clear weather in singapore-onenorth
Reference image shape: torch.Size([3, 320, 512])
Conditioning image shape: torch.Size([14, 4, 320, 512])
Batch 1 has 14 samples
Caption: A driving scene during the day, with clear weather in singapore-onenorth
Reference image shape: torch.Size([3, 320, 512])
Conditioning image shape: torch.Size([14, 4, 320, 512])
Batch 2 has 14 samples
Caption: A driving scene during the day, with clear weather in singapore-onenorth
Reference image shape: torch.Size([3, 320, 512])
Conditioning image shape: torch.Size([14, 4, 320, 512])


In [2]:
# Build the base Stable Video Diffusion pipeline, wrap its UNet together with a
# conditioning network, and run one conditioned generation.
#
# Depends on `batch` from the dataloader preview loop in the previous cell.

# `ignore_mismatched_sizes` was dropped: the pipeline reports it as an unexpected
# keyword argument and ignores it.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
)
config = {
    "output_size": (40, 64),  # 320x512 frames divided by the VAE scale factor of 8
    "num_channels": 4
}
# Create a wrapper model around the pretrained UNet and the conditioning network
custom_conditioning_net = CustomConditioningNet(**config)
model = pipe.unet
wrapper_model = wrapperModel(customConditioningNet=custom_conditioning_net, model=model)
# Create a wrapper pipeline that reuses the base pipeline's components
pipe_with_wrapper = StableVideoDiffusionPipelineWithWrapper(
    vae=pipe.vae,
    image_encoder=pipe.image_encoder,
    scheduler=pipe.scheduler,
    feature_extractor=pipe.feature_extractor,
    wrapper=wrapper_model
)

# NOTE: `prompt` is not passed to the pipeline call below.
prompt = "A driving scene during the night, with rainy weather in boston-seaport"
# prompt = batch['caption']
pseudo_sample = batch['conditioning'].to(dtype=torch.float16, device=torch.device("cuda"))
# CPU generator with a fixed seed for reproducible sampling
generator = torch.Generator().manual_seed(42)
image = batch['reference_image']

# Moves the base pipeline's components to the GPU in fp16; pipe_with_wrapper was
# built from those same component objects.
# NOTE(review): custom_conditioning_net is not a component of `pipe`, so this call
# does not move or cast it -- confirm the wrapper pipeline handles its placement.
pipe = pipe.to(dtype=torch.float16, device=torch.device("cuda"))

honden = pipe_with_wrapper(
    height=320,
    width=512,
    image=image,
    conditioning_image=pseudo_sample,
    num_frames=14,
    decode_chunk_size=8,
    generator=generator,
).frames[0]
export_to_video(honden, "nolatents@.mp4", fps=7)

Keyword arguments {'ignore_mismatched_sizes': False} are not expected by StableVideoDiffusionPipeline and will be ignored.
Loading pipeline components...: 100%|██████████| 5/5 [00:01<00:00,  4.78it/s]


VAE scale factor: 8


  0%|          | 0/25 [00:00<?, ?it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


  4%|▍         | 1/25 [00:00<00:10,  2.20it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


  8%|▊         | 2/25 [00:00<00:09,  2.31it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 12%|█▏        | 3/25 [00:01<00:09,  2.33it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 16%|█▌        | 4/25 [00:01<00:08,  2.35it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 20%|██        | 5/25 [00:02<00:08,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 24%|██▍       | 6/25 [00:02<00:08,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 28%|██▊       | 7/25 [00:02<00:07,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 32%|███▏      | 8/25 [00:03<00:07,  2.37it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 36%|███▌      | 9/25 [00:03<00:06,  2.37it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 40%|████      | 10/25 [00:04<00:06,  2.37it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 44%|████▍     | 11/25 [00:04<00:05,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 48%|████▊     | 12/25 [00:05<00:05,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 52%|█████▏    | 13/25 [00:05<00:05,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 56%|█████▌    | 14/25 [00:05<00:04,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 60%|██████    | 15/25 [00:06<00:04,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 64%|██████▍   | 16/25 [00:06<00:03,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 68%|██████▊   | 17/25 [00:07<00:03,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 72%|███████▏  | 18/25 [00:07<00:02,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 76%|███████▌  | 19/25 [00:08<00:02,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 80%|████████  | 20/25 [00:08<00:02,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 84%|████████▍ | 21/25 [00:08<00:01,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 88%|████████▊ | 22/25 [00:09<00:01,  2.36it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 92%|█████████▏| 23/25 [00:09<00:00,  2.37it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


 96%|█████████▌| 24/25 [00:10<00:00,  2.37it/s]

 this is the shape of the latent model input: torch.Size([2, 14, 4, 40, 64]) 
this si teh shape of the conditioning torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x torch.Size([2, 14, 4, 40, 64])
this si teh shape of the x after the concatenation torch.Size([2, 14, 8, 40, 64])
 this si the value of the noise pred: torch.Size([2, 14, 4, 40, 64])


100%|██████████| 25/25 [00:10<00:00,  2.36it/s]


'nolatents@.mp4'

In [8]:
# Inspect the shape of the conditioning tensor currently held in `batch`.
# NOTE(review): this cell ran as In [8], so `batch` is whatever the kernel held at
# that point -- presumably still the last batch from the preview loop; confirm.
batch['conditioning'].shape

torch.Size([14, 4, 320, 512])