# Imports

In [None]:
import sys
from pathlib import Path

import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

In [None]:
# Put the directory above the current working directory on the import path
# (only once), so the `utils` package imported in the next cell can be resolved
# -- presumably it lives there; confirm against the repository layout.
parent_dir = str(Path.cwd().parent)
if sys.path.count(parent_dir) == 0:
    sys.path.append(parent_dir)

In [None]:
from utils.models import VideoTimeEncoding

In [None]:
# Inference-only notebook: turn autograd off globally so no graph is recorded.
torch.set_grad_enabled(mode=False)

# Models

In [None]:
# Feature width of the conditioning vectors consumed by the cross-attention
# layers; the same value is reused as `time_embed_dim` of the video-time encoder.
cross_attention_dim = 128

# Conditional UNet for 3-channel inputs with a sample size of 128. Every stage
# except the outermost one (first down block / last up block) uses
# cross-attention.
net = UNet2DConditionModel(
    # Input / output geometry.
    sample_size=128,
    in_channels=3,
    out_channels=3,
    # Block layout.
    block_out_channels=(64, 128, 256),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
    layers_per_block=2,
    # Attention / conditioning.
    attention_head_dim=8,
    cross_attention_dim=cross_attention_dim,
    # Misc.
    act_fn="silu",
    dropout=0,
)
net = net.to("cuda")
net

In [None]:
# Encoder for the video time. `time_embed_dim` is tied to `cross_attention_dim`
# because the encoder's output is later handed to the UNet as
# `encoder_hidden_states`.
video_time_encoding = VideoTimeEncoding(
    time_embed_dim=cross_attention_dim,
    encoding_dim=128,
    downscale_freq_shift=1,
    flip_sin_to_cos=True,
)
video_time_encoding = video_time_encoding.to("cuda")
video_time_encoding

# Forward pass

In [None]:
# Number of samples in the dummy batch used by the forward-pass cells below.
batch_size = 16

In [None]:
# Dummy inputs for a smoke-test forward pass, created directly on the GPU
# (instead of allocating on the CPU and copying with `.to("cuda")`).
# - noisy_batch: Gaussian noise shaped (batch, 3, 128, 128), matching the
#   `in_channels=3` / `sample_size=128` the UNet was built with.
# - diff_timesteps: one integer timestep per sample, drawn from [0, 1000).
noisy_batch = torch.randn(batch_size, 3, 128, 128, device="cuda")
diff_timesteps = torch.randint(0, 1000, (batch_size,), device="cuda")

In [None]:
# Encode one video time for the whole batch. 0.72 is an arbitrary example
# value -- NOTE(review): presumably a normalized time; confirm the expected
# range against VideoTimeEncoding. `unsqueeze(1)` inserts a length-1 axis at
# position 1 so the codes can be passed as `encoder_hidden_states`.
#
# The modules are invoked by calling the instance rather than `.forward(...)`:
# calling `forward` directly skips any registered forward hooks.
video_time_codes = video_time_encoding(0.72, batch_size).unsqueeze(1)

print(noisy_batch.shape, diff_timesteps.shape, video_time_codes.shape)
# With `return_dict=False` the UNet returns a tuple; element 0 is the prediction.
pred = net(noisy_batch, diff_timesteps, encoder_hidden_states=video_time_codes, return_dict=False)[0]
pred.shape