In [19]:

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from einops import rearrange


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

class ControlNetConditioningEmbeddingSVD(nn.Module):
    """
    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 1,
        block_out_channels: Tuple[int, ...] = (32, 64, 128, 256),
    ):
        super().__init__()


        
        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        #this seeems appropriate? idk if i should be applying a more complex setup to handle the frames
        #combine batch and frames dimensions
        batch_size, frames, channels, height, width = conditioning.size()
        conditioning = conditioning.view(batch_size * frames, channels, height, width)

        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)
        
        #split them apart again
        #actually not needed
        #new_channels, new_height, new_width = embedding.shape[1], embedding.shape[2], embedding.shape[3]
        #embedding = embedding.view(batch_size, frames, new_channels, new_height, new_width)


        return embedding

In [25]:
import torch
import torch.nn.functional as F
from einops import rearrange # Make sure to import this

# --- Your Setup ---
# Assuming controlnet_cond_embedding is a defined nn.Module
# def controlnet_cond_embedding(x): return x # Example placeholder

# Your initial tensor
controlnet_cond = torch.rand([1, 5, 6, 375, 375])
B, T = controlnet_cond.shape[0], controlnet_cond.shape[1] # Get Batch and Time
target_h, target_w = 320, 512

print(f"Original shape: {controlnet_cond.shape}")

# --- Start Correction ---
# 1. Flatten Batch and Time dimensions for 2D interpolation:
# Shape changes from [B, T, C, H, W] -> [B*T, C, H, W]
controlnet_cond_flat = rearrange(controlnet_cond, 'b t c h w -> (b t) c h w')

# 2. Apply 2D spatial interpolation
controlnet_cond_resized = F.interpolate(
    controlnet_cond_flat,
    size=(target_h, target_w), # Target (H, W)
    mode='bilinear',          # Use bilinear interpolation
    align_corners=False       # Recommended setting
)

# 3. Reshape back to 5D:
# Shape changes from [B*T, C, H_new, W_new] -> [B, T, C, H_new, W_new]
controlnet_cond = rearrange(controlnet_cond_resized, '(b t) c h w -> b t c h w', b=B)
# --- End Correction ---

print(f"Shape after resize: {controlnet_cond.shape}")


conditioning_embedding_out_channels : Optional[Tuple[int, ...]] = (32, 64, 128, 256)
controlnet_cond_embedding = ControlNetConditioningEmbeddingSVD(
            conditioning_embedding_channels=320,
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=6)

controlnet_cond = controlnet_cond_embedding(controlnet_cond)
print(controlnet_cond.shape)

Original shape: torch.Size([1, 5, 6, 375, 375])
Shape after resize: torch.Size([1, 5, 6, 320, 512])
torch.Size([5, 320, 40, 64])


In [None]:
conditioning_embedding_out_channels : Optional[Tuple[int, ...]] = (32, 64, 128, 256)
controlnet_cond_embedding = ControlNetConditioningEmbeddingSVD(
            conditioning_embedding_channels=320,
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=6)

controlnet_cond = controlnet_cond_embedding(controlnet_cond)
print(controlnet_cond.shape)

torch.Size([5, 320, 47, 47])


In [29]:
conv_in = nn.Conv2d(
    8,
    320,
    kernel_size=3,
    padding=1,
)
input = torch.rand(([1, 5, 8, 46, 46]))
sample = input.flatten(0, 1)
print(sample.shape)
k = conv_in(sample)
print(k.shape)


import torch
import torch.nn.functional as F

# Your two tensors with mismatched H/W
sample = torch.randn(5, 320, 46, 46)
events_condition = torch.randn(5, 320, 40, 64)

print(f"Original 'sample' shape:     {sample.shape}")
print(f"Original 'events' shape: {events_condition.shape}")

# 1. Get the target H/W from the 'events_condition' tensor
target_height = events_condition.shape[2]  # 40
target_width = events_condition.shape[3]   # 64

# 2. Resize 'sample' to match the target size
sample_resized = F.interpolate(
    sample, 
    size=(target_height, target_width), # (40, 64)
    mode='bilinear', 
    align_corners=False
)

print(f"Resized 'sample' shape:    {sample_resized.shape}")

# Now they are the same size and can be combined
combined_tensor = sample_resized + events_condition
print(f"Combined shape:          {combined_tensor.shape}")

torch.Size([5, 8, 46, 46])
torch.Size([5, 320, 46, 46])
Original 'sample' shape:     torch.Size([5, 320, 46, 46])
Original 'events' shape: torch.Size([5, 320, 40, 64])
Resized 'sample' shape:    torch.Size([5, 320, 40, 64])
Combined shape:          torch.Size([5, 320, 40, 64])
