# Temporal unet


In [1]:
import debugpy
import gc
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from typing import Optional, Tuple, Union
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from diffusers.models.unets import UNetSpatioTemporalConditionModel
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
from diffusers.models.unets import UNetSpatioTemporalConditionModel


from types import MethodType
torch.cuda.empty_cache()

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


In [2]:
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16", num_frames = 2
)
pipe.enable_model_cpu_offload()

Keyword arguments {'num_frames': 2} are not expected by StableVideoDiffusionPipeline and will be ignored.
Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00,  6.74it/s]


# Controlnet Initialization

In [4]:
# initialize the contrl net from my_net

from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion_with_controlnet import StableVideoDiffusionPipelineWithControlNet,SpatioTemporalControlNet, CustomConditioningNet, SpatioTemporalControlNetOutput

import gc
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
tokenizer = pipeline.tokenizer
text_encoder = pipeline.text_encoder
del pipeline
gc.collect() 

pipe_config = pipe.config
print(pipe_config)
unet_weights = pipe.unet.state_dict()
my_net = UNetSpatioTemporalConditionModel()
my_net.load_state_dict(unet_weights)
control_net = SpatioTemporalControlNet.from_unet(my_net)

Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 10.84it/s]


FrozenDict([('vae', ('diffusers', 'AutoencoderKLTemporalDecoder')), ('image_encoder', ('transformers', 'CLIPVisionModelWithProjection')), ('unet', ('diffusers', 'UNetSpatioTemporalConditionModel')), ('scheduler', ('diffusers', 'EulerDiscreteScheduler')), ('feature_extractor', ('transformers', 'CLIPImageProcessor')), ('_name_or_path', 'stabilityai/stable-video-diffusion-img2vid-xt')])


In [None]:
pipe_with_controlnet = StableVideoDiffusionPipelineWithControlNet(
    vae = pipe.vae,
    image_encoder = pipe.image_encoder,
    unet=my_net,
    scheduler=pipe.scheduler,
    feature_extractor=pipe.feature_extractor,
    controlnet=control_net,
    tokenizer = tokenizer,
    text_encoder = text_encoder
)
pipe_with_controlnet.enable_model_cpu_offload()

In [None]:
prompt = "d"
pseudo_sample = torch.randn(25, 4, 578, 1028)
# Define a simple torch generator
generator = torch.Generator().manual_seed(42)
image = load_image("/home/wisley/custom_diffusers_library/frame2.png")
frames = pipe_with_controlnet(image = image, prompt=prompt, conditioning_image = pseudo_sample,  decode_chunk_size=8, generator=generator).frames[0]


Shape of prompt embeds: torch.Size([1, 77, 1024]) 1 77


  if next(self.conditioning_embedding.parameters()).is_cuda:
  conditioning_nn = next(self.conditioning_embedding.parameters()).device
  conditioning_nn_dtype = next(self.conditioning_embedding.parameters()).dtype
  controlnet_condition =  self.conditioning_embedding.forward(controlnet_condition)
100%|██████████| 25/25 [01:34<00:00,  3.78s/it]


In [None]:
export_to_video(frames, "jajavi.mp4", fps=7)

'jajavi.mp4'

# Run the models


In [None]:
torch.cuda.empty_cache()
def prepare_latents(
    batch_size,
    num_frames,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    shape = (
        batch_size,
        num_frames,
        num_channels_latents // 2,
        height // 1,
        width // 1,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * 0.2
    return latents

def pseudo_image_embeddings( shape, generator, device, dtype, do_classifier_free_guidance = True ):
    image_embeddings = randn_tensor(shape, generator=generator, device= device, dtype=dtype)

    if do_classifier_free_guidance:
        negative_image_embeddings = torch.zeros_like(image_embeddings)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        return torch.cat([negative_image_embeddings, image_embeddings])

def get_add_time_ids(
  fps = 7,
  motion_bucket_id = 127,
  noise_aug_strength = 0.02,
  dtype = torch.float32,
  batch_size = 1,
  num_videos_per_prompt = 1,
  do_classifier_free_guidance = True,
):
  add_time_ids = [fps, motion_bucket_id, noise_aug_strength]

  add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
  add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)

  if do_classifier_free_guidance:
      add_time_ids = torch.cat([add_time_ids, add_time_ids])

  return add_time_ids

dtype = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = torch.Generator().manual_seed(42)  # For reproducibility

# Generate original latents with specified dtype
my_latents = prepare_latents(1, 1, 8, 64, 64, dtype, device, generator)

# Apply classifier-free guidance by duplicating the latents and ensuring the correct dtype
latent_model_input = torch.cat([my_latents] * 2)  # Inherits dtype from my_latents

# Create pseudo image latents by cloning the original latents
pseudo_image_latents = latent_model_input.clone()  # Inherits dtype

# Concatenate pseudo image latents over the channels dimension, ensuring matching dtype
latent_model_input = torch.cat([latent_model_input, pseudo_image_latents], dim=2)


# Create the fake image embeddings with the specified dtype
hidden_image_embeddings = pseudo_image_embeddings((1, 1, 1024), generator, device, dtype)

added_time_ids = get_add_time_ids(dtype=dtype).to(device)

# Verify the dtype of both tensors
print(f"Latent model input dtype: {latent_model_input.dtype}")
print(f"Hidden image embeddings dtype: {hidden_image_embeddings.dtype}")

print(latent_model_input.shape)
print(hidden_image_embeddings.shape)
print(added_time_ids.shape)

# Print on which model they are
print(f"Latent model input is on: {latent_model_input.device}")
print(f"Hidden image embeddings are on: {hidden_image_embeddings.device}")
# Assuming added_time_ids is also a tensor; replace this with the actual tensor variable if different
print(f"Added time IDs are on: {added_time_ids.device}")

Latent model input dtype: torch.float16
Hidden image embeddings dtype: torch.float16
torch.Size([2, 1, 8, 64, 64])
torch.Size([2, 1, 1024])
torch.Size([2, 3])
Latent model input is on: cuda:0
Hidden image embeddings are on: cuda:0
Added time IDs are on: cuda:0


In [None]:
with torch.no_grad():

    noise_pred = my_net.forward(
        latent_model_input.to(dtype=dtype),
        torch.tensor(1).to(dtype=dtype, device=device),
        encoder_hidden_states=hidden_image_embeddings.to(dtype=dtype),
        added_time_ids=added_time_ids.to(dtype=dtype),
        down_block_additional_residuals= None,
        mid_block_additional_residual = None,
        return_dict=False,
    )[0]

    print(noise_pred.shape)
    if noise_pred is not None:
        del noise_pred



torch.Size([2, 1, 4, 64, 64])


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
from typing import Optional, List

def encode_prompt(
    prompt,
    device,
    do_classifier_free_guidance,
    negative_prompt="Simulation, artifacts, blurry, low resolution, low quality, noisy, grainy, distorted",
    num_images_per_prompt = 1,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
    text_encoder = None, 
    tokenizer = None):
    # Set the text_encoder and the tokenizer on the correct device  
    text_encoder = text_encoder.to(device)


    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if True:

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = tokenizer.batch_decode(
                untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(text_encoder.config, "use_attention_mask") and text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None


        prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
        prompt_embeds = prompt_embeds[0]
       

    if text_encoder is not None:
        prompt_embeds_dtype = text_encoder.dtype


    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)


    bs_embed, seq_len, _ = prompt_embeds.shape
    print(f"Shape of prompt embeds: {prompt_embeds.shape} {bs_embed} {seq_len}")
    
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt
        
        max_length = prompt_embeds.shape[1]
        uncond_input = tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(text_encoder.config, "use_attention_mask") and text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)


    
    embeds = torch.cat([prompt_embeds, negative_prompt_embeds])
    embeds = embeds.mean(dim=1, keepdim=True)

    return embeds

prompt = "A f a cat"
prompt_embeds = encode_prompt(prompt=prompt, device=device, do_classifier_free_guidance=True, text_encoder=text_encoder, tokenizer=tokenizer)
print(prompt_embeds.shape)



Shape of prompt embeds: torch.Size([1, 77, 1024]) 1 77
torch.Size([2, 1, 1024])


In [None]:
with torch.no_grad():
    time =  torch.tensor(1).to(dtype=dtype)

    # print on which device the input is
    print(f"Latent model input is on: {latent_model_input.device}")
    print(f"Hidden image embeddings are on: {hidden_image_embeddings.device}")
    print(f"Added time IDs are on: {added_time_ids.device}")
    print(f"Time is on: {time.device}")

    # move time to divice
    time = time.to(device)
    print(f"Time is on: {time.device}")
           
    noise_pred = control_net.forward(
        torch.ones(2, 2, 8, 64, 64).to(dtype=dtype, device=device),
        time,
        added_time_ids=added_time_ids.to(dtype=dtype),
        encoder_hidden_states = prompt_embeds.to(dtype=dtype),
        # encoder_hidden_states = None,
        return_dict=True,
        # controlnet_condition = torch.ones(25, 4, 576, 1024).to(dtype=dtype, device=device)
    )    
    

    # Print the sizes of the tensors
 
    if noise_pred is not None:
        del noise_pred


Latent model input is on: cuda:0
Hidden image embeddings are on: cuda:0
Added time IDs are on: cuda:0
Time is on: cpu
Time is on: cuda:0


In [None]:


with torch.no_grad():
    (down_block_res_samples, mid_block_res_samples) = control_net.forward(
        latent_model_input.to(dtype=dtype),
        time,
        encoder_hidden_states=hidden_image_embeddings.to(dtype=dtype),
        added_time_ids=added_time_ids.to(dtype=dtype),
        return_dict=False,
    )

    # reverse the down_block_res_samples tuple
    # down_block_res_samples = down_block_res_samples[::-1]

    print(f"Length of down_block_res_samples afterprocesssing: {len(down_block_res_samples)}")


    noise_pred = my_net.forward(
        latent_model_input.to(dtype=dtype),
        torch.tensor(1).to(dtype=dtype, device=device),
        encoder_hidden_states=hidden_image_embeddings.to(dtype=dtype),
        # Maybe I need to reverse the order of the tensors
        down_block_additional_residuals= down_block_res_samples,
        mid_block_additional_residual = mid_block_res_samples,
        added_time_ids=added_time_ids.to(dtype=dtype),
        return_dict=False,
    )[0]

    # Print the sizes of the tensors
    if noise_pred is not None:
        del noise_pred


Length of down_block_res_samples afterprocesssing: 12


In [None]:
conditioning_net = CustomConditioningNet()
pseudo_sample = torch.randn(25, 4, 578, 1028)
hoi = conditioning_net.forward(pseudo_sample)
print(hoi.shape)

torch.Size([2, 25, 320, 72, 128])
