<a href="https://colab.research.google.com/github/wr0124/Learning_essential/blob/main/ResBlock.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ResBlock: modifications inside the class

In [2]:
pip install einops

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


In [4]:
import torch
import torch.nn as nn
from einops import rearrange

class InflatedGroupNorm(nn.GroupNorm):
    """GroupNorm for 5-D video tensors laid out as ``(b, c, f, h, w)``.

    The frame axis is folded into the batch axis so the stock
    ``nn.GroupNorm`` forward pass sees a 4-D ``(b*f, c, h, w)`` tensor, and
    the original layout is restored afterwards.
    """

    def forward(self, x):
        """Normalize ``x`` of shape ``(b, c, f, h, w)`` and return the same shape."""
        # Keep the frame count around; it is needed to undo the fold below.
        num_frames = x.shape[2]

        # (b, c, f, h, w) -> (b*f, c, h, w): every frame becomes its own sample.
        per_frame = rearrange(x, "b c f h w -> (b f) c h w")

        # Plain GroupNorm on the folded tensor.
        normed = super().forward(per_frame)

        # (b*f, c, h, w) -> (b, c, f, h, w): split the frames back out.
        return rearrange(normed, "(b f) c h w -> b c f h w", f=num_frames)

def normalization(channels, norm="groupnorm32"):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :param norm: normalization spec of the form ``"groupnorm<G>"``, where
        ``<G>`` is the positive number of groups (e.g. ``"groupnorm32"``).
    :return: an nn.Module for normalization.
    :raises ValueError: if ``norm`` is not a groupnorm spec, if its group
        count is missing, non-numeric or not positive, or if ``channels`` is
        not divisible by the group count.
    """
    if "groupnorm" not in norm:
        raise ValueError(f"Normalization method {norm} not supported.")

    # Everything after the "groupnorm" marker is expected to be the group count.
    suffix = norm.split("groupnorm")[1]
    try:
        group_norm_size = int(suffix)
    except ValueError:
        # Replace the opaque "invalid literal for int()" message with one that
        # tells the caller what the spec should look like.
        raise ValueError(
            f"Normalization method {norm} must include a group count after "
            f"'groupnorm', e.g. 'groupnorm32'."
        ) from None

    # Reject zero (would raise ZeroDivisionError in the modulo below) and
    # negative counts (which could pass the divisibility check).
    if group_norm_size <= 0:
        raise ValueError(
            f"group count must be a positive integer, got {group_norm_size} (from {norm!r})"
        )

    if channels % group_norm_size != 0:
        raise ValueError(f"channels ({channels}) must be divisible by group_norm_size ({group_norm_size})")

    return InflatedGroupNorm(group_norm_size, channels)

# Smoke test: push a random video batch through the normalization layer.
if __name__ == "__main__":
    # Layout is (batch_size, channels, frames, height, width). 64 channels is
    # used so the 32-group split divides evenly; channel counts such as
    # 64, 128, 256, 512, 1024, 768, 384 and 192 are all divisible by 32.
    x = torch.randn(2, 64, 8, 32, 32)

    # Build the layer and run the batch through it.
    norm_layer = normalization(64, "groupnorm32")
    y = norm_layer(x)

    # Expected output: torch.Size([2, 64, 8, 32, 32])
    print(y.shape)


torch.Size([2, 64, 8, 32, 32])


# ResBlock class: modifying the input

In [9]:
import torch

# Example tensor x with shape [b, c, f, h, w]
x = torch.randn(4, 3, 2, 64, 64)

# Get the dimensions of x
b, c, f, h, w = x.shape

# Fold the frame axis into the batch axis: [b, c, f, h, w] -> [b*f, c, h, w].
# b and f are not adjacent axes, so a bare view() to this shape would produce
# the right shape but mix channel and frame data. Move f next to b first, then
# merge the two axes. reshape() is used instead of view() because the permuted
# tensor is no longer contiguous.
new_shape = (b * f, c, h, w)
reshaped_tensor = x.permute(0, 2, 1, 3, 4).reshape(new_shape)

# Row (i * f + j) of the result must be frame j of sample i.
assert torch.equal(reshaped_tensor[1], x[0, :, 1])

# Print the shapes to verify
print("Original shape:", x.shape)
print("Reshaped shape:", reshaped_tensor.shape)


Original shape: torch.Size([4, 3, 2, 64, 64])
Reshaped shape: torch.Size([8, 3, 64, 64])


In [11]:
import torch

# Example tensor x with shape [b*f, c, h, w] (frames folded into the batch axis)
x = torch.randn(8, 3, 64, 64)

# Get the dimensions of x
bf, c, h, w = x.shape

# Recover the frame count from the known batch size
b = 1
f = bf // b

# Unfold the frames: [b*f, c, h, w] -> [b, c, f, h, w].
# The leading axis is ordered as (b f), so it must first be split into
# [b, f, c, h, w] and only then have f moved behind c. Viewing directly as
# [b, c, f, h, w] would give the right shape but scramble channels and frames.
new_shape = (b, c, f, h, w)
reshaped_tensor = x.view(b, f, c, h, w).permute(0, 2, 1, 3, 4)

# Frame j of sample i must be row (i * f + j) of the input.
assert reshaped_tensor.shape == new_shape
assert torch.equal(reshaped_tensor[0, :, 1], x[1])

# Print the shapes to verify
print("Original shape:", x.shape)
print("Reshaped shape:", reshaped_tensor.shape)


Original shape: torch.Size([8, 3, 64, 64])
Reshaped shape: torch.Size([1, 3, 8, 64, 64])
