# Temporal unet


In [1]:
import debugpy
from typing import Optional, Tuple, Union
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from diffusers.models.unets import UNetSpatioTemporalConditionModel
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
from diffusers.models.unets import UNetSpatioTemporalConditionModel


from types import MethodType
torch.cuda.empty_cache()

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


In [2]:
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16", num_frames = 2
)
pipe.enable_model_cpu_offload()

Keyword arguments {'num_frames': 2} are not expected by StableVideoDiffusionPipeline and will be ignored.
Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00, 10.15it/s]


In [3]:
pipe_config = pipe.config
print(pipe_config)
unet_weights = pipe.unet.state_dict()
my_net = UNetSpatioTemporalConditionModel()
my_net.load_state_dict(unet_weights)

FrozenDict([('vae', ('diffusers', 'AutoencoderKLTemporalDecoder')), ('image_encoder', ('transformers', 'CLIPVisionModelWithProjection')), ('unet', ('diffusers', 'UNetSpatioTemporalConditionModel')), ('scheduler', ('diffusers', 'EulerDiscreteScheduler')), ('feature_extractor', ('transformers', 'CLIPImageProcessor')), ('_name_or_path', 'stabilityai/stable-video-diffusion-img2vid-xt')])


<All keys matched successfully>

In [4]:
torch.cuda.empty_cache()
def prepare_latents(
    batch_size,
    num_frames,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    shape = (
        batch_size,
        num_frames,
        num_channels_latents // 2,
        height // 1,
        width // 1,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * 0.2
    return latents

def pseudo_image_embeddings( shape, generator, device, dtype, do_classifier_free_guidance = True ):
    image_embeddings = randn_tensor(shape, generator=generator, device= device, dtype=dtype)

    if do_classifier_free_guidance:
        negative_image_embeddings = torch.zeros_like(image_embeddings)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        return torch.cat([negative_image_embeddings, image_embeddings])

def get_add_time_ids(
  fps = 7,
  motion_bucket_id = 127,
  noise_aug_strength = 0.02,
  dtype = torch.float32,
  batch_size = 1,
  num_videos_per_prompt = 1,
  do_classifier_free_guidance = True,
):
  add_time_ids = [fps, motion_bucket_id, noise_aug_strength]

  add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
  add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)

  if do_classifier_free_guidance:
      add_time_ids = torch.cat([add_time_ids, add_time_ids])

  return add_time_ids

dtype = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = torch.Generator().manual_seed(42)  # For reproducibility

# Generate original latents with specified dtype
my_latents = prepare_latents(1, 1, 8, 64, 64, dtype, device, generator)

# Apply classifier-free guidance by duplicating the latents and ensuring the correct dtype
latent_model_input = torch.cat([my_latents] * 2)  # Inherits dtype from my_latents

# Create pseudo image latents by cloning the original latents
pseudo_image_latents = latent_model_input.clone()  # Inherits dtype

# Concatenate pseudo image latents over the channels dimension, ensuring matching dtype
latent_model_input = torch.cat([latent_model_input, pseudo_image_latents], dim=2)


# Create the fake image embeddings with the specified dtype
hidden_image_embeddings = pseudo_image_embeddings((1, 1, 1024), generator, device, dtype)

added_time_ids = get_add_time_ids(dtype=dtype).to(device)

# Verify the dtype of both tensors
print(f"Latent model input dtype: {latent_model_input.dtype}")
print(f"Hidden image embeddings dtype: {hidden_image_embeddings.dtype}")

print(latent_model_input.shape)
print(hidden_image_embeddings.shape)
print(added_time_ids.shape)

# Print on which model they are
print(f"Latent model input is on: {latent_model_input.device}")
print(f"Hidden image embeddings are on: {hidden_image_embeddings.device}")
# Assuming added_time_ids is also a tensor; replace this with the actual tensor variable if different
print(f"Added time IDs are on: {added_time_ids.device}")

Latent model input dtype: torch.float16
Hidden image embeddings dtype: torch.float16
torch.Size([2, 1, 8, 64, 64])
torch.Size([2, 1, 1024])
torch.Size([2, 3])
Latent model input is on: cuda:0
Hidden image embeddings are on: cuda:0
Added time IDs are on: cuda:0


In [5]:
my_net = my_net.half()
my_net.to(device)

UNetSpatioTemporalConditionModel(
  (conv_in): Conv2d(8, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): LoRACompatibleLinear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
  )
  (add_time_proj): Timesteps()
  (add_embedding): TimestepEmbedding(
    (linear_1): LoRACompatibleLinear(in_features=768, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlockSpatioTemporal(
      (attentions): ModuleList(
        (0-1): 2 x TransformerSpatioTemporalModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=320, out_features=320, bias=True)
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlo

In [6]:
torch.cuda.empty_cache()

In [7]:
with torch.no_grad():

    noise_pred = my_net.forward(
        latent_model_input.to(dtype=dtype),
        torch.tensor(1).to(dtype=dtype, device=device),
        encoder_hidden_states=hidden_image_embeddings.to(dtype=dtype),
        added_time_ids=added_time_ids.to(dtype=dtype),
        down_block_additional_residuals= None,
        mid_block_additional_residual = None,
        return_dict=False,
    )[0]

    print(noise_pred.shape)
    if noise_pred is not None:
        del noise_pred



This is the batch size 2
Res samples shape unet: torch.Size([2, 320, 64, 64])
Res samples shape unet: torch.Size([2, 640, 32, 32])
Res samples shape unet: torch.Size([2, 1280, 16, 16])
Res samples shape unet: torch.Size([2, 1280, 8, 8])
Down block res samples shape uuuuuuuu: torch.Size([2, 320, 64, 64])
Down block res samples shape uuuuuuuu: torch.Size([2, 320, 64, 64])
Down block res samples shape uuuuuuuu: torch.Size([2, 320, 64, 64])
Down block res samples shape uuuuuuuu: torch.Size([2, 320, 32, 32])
Down block res samples shape uuuuuuuu: torch.Size([2, 640, 32, 32])
Down block res samples shape uuuuuuuu: torch.Size([2, 640, 32, 32])
Down block res samples shape uuuuuuuu: torch.Size([2, 640, 16, 16])
Down block res samples shape uuuuuuuu: torch.Size([2, 1280, 16, 16])
Down block res samples shape uuuuuuuu: torch.Size([2, 1280, 16, 16])
Down block res samples shape uuuuuuuu: torch.Size([2, 1280, 8, 8])
Down block res samples shape uuuuuuuu: torch.Size([2, 1280, 8, 8])
Down block res 

# Controlnet Initialization

In [8]:
from dataclasses import dataclass
from typing import Optional, Tuple, Union, Dict, Any

class UNetSpatioTemporalConditionOutput(BaseOutput):
    """
    The output of [`UNetSpatioTemporalConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """


    sample: torch.FloatTensor = None


class SpatioTemporalControlNetOutput(BaseOutput):
    """
    The output of [`ControlNetModel`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_down_block_re_sample (`torch.Tensor`):
            The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor
    
    # Add a class which prints the sizes of the tensors
    def print_sizes(self):
        print(f"Down block res samples: {self.down_block_res_samples[0].shape}")
        print(f"Mid block res sample: {self.mid_block_res_sample.shape}")
        





class SpatioTemporalControlNet(ModelMixin, ConfigMixin):
    """
    A SpatioTemporalControlNet model for conditioning on spatio-temporal data.
    This model adapts concepts from both ControlNetModel and UNetSpatioTemporalConditionModel,
    focusing on handling video frames over time.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 8,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlockSpatioTemporal",
            "CrossAttnDownBlockSpatioTemporal",
            "CrossAttnDownBlockSpatioTemporal",
            "DownBlockSpatioTemporal",
        ),
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        addition_time_embed_dim: int = 256,
        projection_class_embeddings_input_dim: int = 768,
        layers_per_block: Union[int, Tuple[int]] = 2,
        cross_attention_dim: Union[int, Tuple[int]] = 1024,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
    ):
        super().__init__()

        self.sample_size = sample_size


        # input
        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            padding=1,
        )

        # time
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)


        output_channel = block_out_channels[0]
        
        self.controlnet_down_blocks = None
        self.down_blocks = nn.ModuleList([])


        
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        blocks_time_embed_dim = time_embed_dim

        # Initialize the connection between the down blocks and the unet
        output_channel = block_out_channels[0]

        # controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
        # controlnet_block = zero_module(controlnet_block)
        # self.controlnet_down_blocks.append(controlnet_block)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                add_downsample= not is_final_block,
                resnet_eps=1e-5,
                down_block_type=down_block_type,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                resnet_act_fn="silu",
            )
            self.down_blocks.append(down_block)

            # for _ in range(layers_per_block[i]):
            #     controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
            #     controlnet_block = zero_module(controlnet_block)
            #     self.controlnet_down_blocks.append(controlnet_block)

            # if not is_final_block:
            #     controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
            #     controlnet_block = zero_module(controlnet_block)
            #     self.controlnet_down_blocks.append(controlnet_block)




        # hardcoded_controlnet_block_dims = [320,320, 640,640, 1280, 1280, 1280,1280,1280]
        # for index, controlnet_block_dim in enumerate(hardcoded_controlnet_block_dims):
        #     controlnet_block = nn.Conv2d(controlnet_block_dim, controlnet_block_dim, kernel_size=1)
        #     controlnet_block = zero_module(controlnet_block)
        #     self.controlnet_down_blocks.append(controlnet_block)


        # Connections for the mid block
        mid_block_channel = block_out_channels[-1]

        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_mid_block = controlnet_block


        # mid
        self.mid_block = UNetMidBlockSpatioTemporal(
            block_out_channels[-1],
            temb_channels=blocks_time_embed_dim,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            cross_attention_dim=cross_attention_dim[-1],
            num_attention_heads=num_attention_heads[-1],
        )

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        added_time_ids: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
        r"""
        The [`UNetSpatioTemporalConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
            added_time_ids: (`torch.FloatTensor`):
                The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
                embeddings and added to the time embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
                tuple.
        Returns:
            [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """





        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        batch_size, num_frames = sample.shape[:2]
        timesteps = timesteps.expand(batch_size)

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb)

        time_embeds = self.add_time_proj(added_time_ids.flatten())
        time_embeds = time_embeds.reshape((batch_size, -1))
        time_embeds = time_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(time_embeds)
        emb = emb + aug_emb

        # Flatten the batch and frames dimensions
        # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
        sample = sample.flatten(0, 1)
        # Repeat the embeddings num_video_frames times
        # emb: [batch, channels] -> [batch * frames, channels]
        emb = emb.repeat_interleave(num_frames, dim=0).to(sample.device)
        # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
        # Let encoder_hidden_states be just zeros in the correct format
        # encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
        shape_encoder_hidden_states = (batch_size * num_frames, 1, 1024)
        encoder_hidden_states = torch.zeros(shape_encoder_hidden_states, device=sample.device).repeat_interleave(num_frames, dim=0).to(dtype=sample.dtype)
        print(f"Shape of encoder hidden states: {encoder_hidden_states.shape}")
        

        # 2. pre-process
        sample = self.conv_in(sample)

        # print the device where the sample is
        print(f"Sample is on: {sample.device}")


        # Print the shape of the encoder_hidden_states
        print(f"Encoder hidden states shape: {encoder_hidden_states.shape}")

        image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    image_only_indicator=image_only_indicator,
                )
            # Print the shapes of the res_samples
            down_block_res_samples += res_samples
            print(f"Res samples shape control: {res_samples[0].shape}")

        # Print the length of the down_block_res_samples
        print(f"Length of down_block_res_samples control: {len(down_block_res_samples)}")

        # 4. mid
        sample = self.mid_block(
            hidden_states=sample,
        temb=emb,
            encoder_hidden_states=encoder_hidden_states,
            image_only_indicator=image_only_indicator,
        )

        
        # 5. Control net blocks

        # initialize the controlnet_down_block_res_samples of it is on embpy nn.ModuleList
        if self.controlnet_down_blocks is None:
            self.controlnet_down_blocks = nn.ModuleList([])


            for down_block_res_sample in down_block_res_samples:
                # Determine the current number of channels in the tensor
                current_channels = down_block_res_sample.size(1)
                
                # Dynamically create a zero convolution block for the current tensor
                controlnet_block = nn.Conv2d(current_channels, current_channels, kernel_size=1)
                controlnet_block = zero_module(controlnet_block).to(down_block_res_sample.device, dtype=sample.dtype)
            
                
                # Store the processed sample for further use
                self.controlnet_down_blocks.append(controlnet_block)
    

        controlnet_down_block_res_samples = ()

        for index , (down_block_res_sample, controlnet_block) in enumerate(zip(down_block_res_samples, self.controlnet_down_blocks)):

            # print to the debug console the device where the down_block_res_sample is
            try:
                # print the size of the down_block_res_sample
                print(f"Down block res sample shape before the conversion: {down_block_res_sample.shape}")
                down_block_res_sample = controlnet_block(down_block_res_sample)
                controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
            except Exception as e:
                # Print the error in conjecuntion with the index
                print(f"Error at index {index}: {e}")

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        down_block_res_samples = [sample for sample in down_block_res_samples]
        mid_block_res_sample = mid_block_res_sample

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return SpatioTemporalControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )

    @classmethod
    def from_unet(
        cls,
        unet: UNetSpatioTemporalConditionModel,
        load_weights_from_unet: bool = True,
    ):
        
        addition_time_embed_dim = (
            unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
        )

        controlnet = cls(
            in_channels=unet.config.in_channels,
            down_block_types=unet.config.down_block_types,
            block_out_channels=unet.config.block_out_channels,  # What are block out channels
            addition_time_embed_dim=addition_time_embed_dim,
            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
            layers_per_block=unet.config.layers_per_block,
            cross_attention_dim=unet.config.cross_attention_dim,
            transformer_layers_per_block=unet.config.transformer_layers_per_block,
            num_attention_heads=unet.config.num_attention_heads,
        )

        if load_weights_from_unet:
            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
            controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

            controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

        return controlnet


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

In [9]:
# initialize the contrl net from my_net
control_net = SpatioTemporalControlNet.from_unet(my_net)
control_net = control_net.half()
control_net.to(device)

SpatioTemporalControlNet(
  (conv_in): Conv2d(8, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): LoRACompatibleLinear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
  )
  (add_time_proj): Timesteps()
  (add_embedding): TimestepEmbedding(
    (linear_1): LoRACompatibleLinear(in_features=768, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlockSpatioTemporal(
      (attentions): ModuleList(
        (0-1): 2 x TransformerSpatioTemporalModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=320, out_features=320, bias=True)
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
    

In [10]:
with torch.no_grad():
    time =  torch.tensor(1).to(dtype=dtype)

    # print on which device the input is
    print(f"Latent model input is on: {latent_model_input.device}")
    print(f"Hidden image embeddings are on: {hidden_image_embeddings.device}")
    print(f"Added time IDs are on: {added_time_ids.device}")
    print(f"Time is on: {time.device}")

    # move time to divice
    time = time.to(device)
    print(f"Time is on: {time.device}")
           
    noise_pred = control_net.forward(
        latent_model_input.to(dtype=dtype),
        time,
        encoder_hidden_states=hidden_image_embeddings.to(dtype=dtype),
        added_time_ids=added_time_ids.to(dtype=dtype),
        return_dict=True,
    )    
    

    # Print the sizes of the tensors
 
    if noise_pred is not None:
        del noise_pred


Latent model input is on: cuda:0
Hidden image embeddings are on: cuda:0
Added time IDs are on: cuda:0
Time is on: cpu
Time is on: cuda:0
Shape of encoder hidden states: torch.Size([2, 1, 1024])
Sample is on: cuda:0
Encoder hidden states shape: torch.Size([2, 1, 1024])
Res samples shape control: torch.Size([2, 320, 64, 64])
Res samples shape control: torch.Size([2, 640, 32, 32])
Res samples shape control: torch.Size([2, 1280, 16, 16])
Res samples shape control: torch.Size([2, 1280, 8, 8])
Length of down_block_res_samples control: 12
Down block res sample shape before the conversion: torch.Size([2, 320, 64, 64])
Down block res sample shape before the conversion: torch.Size([2, 320, 64, 64])
Down block res sample shape before the conversion: torch.Size([2, 320, 64, 64])
Down block res sample shape before the conversion: torch.Size([2, 320, 32, 32])
Down block res sample shape before the conversion: torch.Size([2, 640, 32, 32])
Down block res sample shape before the conversion: torch.Size(

In [11]:


with torch.no_grad():
    (down_block_res_samples, mid_block_res_samples) = control_net.forward(
        latent_model_input.to(dtype=dtype),
        time,
        encoder_hidden_states=hidden_image_embeddings.to(dtype=dtype),
        added_time_ids=added_time_ids.to(dtype=dtype),
        return_dict=False,
    )

    # reverse the down_block_res_samples tuple
    # down_block_res_samples = down_block_res_samples[::-1]

    print(f"Length of down_block_res_samples afterprocesssing: {len(down_block_res_samples)}")


    noise_pred = my_net.forward(
        latent_model_input.to(dtype=dtype),
        torch.tensor(1).to(dtype=dtype, device=device),
        encoder_hidden_states=hidden_image_embeddings.to(dtype=dtype),
        # Maybe I need to reverse the order of the tensors
        down_block_additional_residuals= down_block_res_samples,
        mid_block_additional_residual = mid_block_res_samples,
        added_time_ids=added_time_ids.to(dtype=dtype),
        return_dict=False,
    )[0]

    # Print the sizes of the tensors
    if noise_pred is not None:
        del noise_pred


Shape of encoder hidden states: torch.Size([2, 1, 1024])
Sample is on: cuda:0
Encoder hidden states shape: torch.Size([2, 1, 1024])
Res samples shape control: torch.Size([2, 320, 64, 64])
Res samples shape control: torch.Size([2, 640, 32, 32])


Res samples shape control: torch.Size([2, 1280, 16, 16])
Res samples shape control: torch.Size([2, 1280, 8, 8])
Length of down_block_res_samples control: 12
Down block res sample shape before the conversion: torch.Size([2, 320, 64, 64])
Down block res sample shape before the conversion: torch.Size([2, 320, 64, 64])
Down block res sample shape before the conversion: torch.Size([2, 320, 64, 64])
Down block res sample shape before the conversion: torch.Size([2, 320, 32, 32])
Down block res sample shape before the conversion: torch.Size([2, 640, 32, 32])
Down block res sample shape before the conversion: torch.Size([2, 640, 32, 32])
Down block res sample shape before the conversion: torch.Size([2, 640, 16, 16])
Down block res sample shape before the conversion: torch.Size([2, 1280, 16, 16])
Down block res sample shape before the conversion: torch.Size([2, 1280, 16, 16])
Down block res sample shape before the conversion: torch.Size([2, 1280, 8, 8])
Down block res sample shape before the con

In [12]:

from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion_with_controlnet import StableVideoDiffusionPipelineWithControlNet

In [14]:
pipe_with_controlnet = StableVideoDiffusionPipelineWithControlNet(
    vae = pipe.vae,
    image_encoder = pipe.image_encoder,
    unet=my_net,
    scheduler=pipe.scheduler,
    feature_extractor=pipe.feature_extractor,
    controlnet=control_net
)




pipe_with_controlnet.enable_model_cpu_offload()