# Norm

In [7]:
import torch
import torch.nn as nn

# Input of shape (N, C, H, W) = (2, 6, 3, 3). Every 3x3 channel map is constant:
# channel k (1..6) holds the value k in sample 0 and 10 + k in sample 1.
x = torch.tensor(
    [[[[float(offset + ch)] * 3] * 3 for ch in range(1, 7)] for offset in (0, 10)]
)

# BatchNorm2d without learnable affine parameters.
bn = nn.BatchNorm2d(x.size(1), affine=False)
y_bn = bn(x)
print("BatchNorm 输出（第一个通道）:\n", y_bn[0, 0])

# LayerNorm over the full (C, H, W) extent of each sample, no learnable parameters.
ln = nn.LayerNorm(tuple(x.shape[1:]), elementwise_affine=False)
y_ln = ln(x)
print("LayerNorm 输出（第一个样本的均值）:", y_ln[0].mean().item())

# InstanceNorm2d without learnable affine parameters; the demo channels are
# constant, so the printed channel comes out as zeros.
in_norm = nn.InstanceNorm2d(x.size(1), affine=False)
y_in = in_norm(x)
print("InstanceNorm 输出（第一个样本的第一个通道）:\n", y_in[0, 0])

# GroupNorm with 3 groups over 6 channels (two channels per group).
gn = nn.GroupNorm(3, x.size(1), affine=False)
y_gn = gn(x)
print("GroupNorm 输出（第一个样本的组0）:\n", y_gn[0, :2])

BatchNorm 输出（第一个通道）:
 tensor([[-1.0000, -1.0000, -1.0000],
        [-1.0000, -1.0000, -1.0000],
        [-1.0000, -1.0000, -1.0000]])
LayerNorm 输出（第一个样本的均值）: -8.278422569674149e-08
InstanceNorm 输出（第一个样本的第一个通道）:
 tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
GroupNorm 输出（第一个样本的组0）:
 tensor([[[-1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000]],

        [[ 1.0000,  1.0000,  1.0000],
         [ 1.0000,  1.0000,  1.0000],
         [ 1.0000,  1.0000,  1.0000]]])


# DiT

In [8]:
import torch
import torch.nn as nn
import math
from timm.models.vision_transformer import Attention, Mlp

import warnings
warnings.filterwarnings(action="ignore")

def modulate(x, shift, scale):
    """Scale and shift ``x`` with per-sample modulation tensors.

    ``shift`` and ``scale`` each get a singleton axis inserted at position 1
    so that they broadcast against ``x`` along that axis. The result is
    ``x * (1 + scale) + shift``, so a zero ``scale`` and zero ``shift``
    leave ``x`` unchanged.
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset

class DitBlock(nn.Module):
    """Attention + MLP block whose normalised inputs are modulated by ``c``.

    A SiLU followed by a linear layer maps the conditioning tensor ``c`` to
    ``6 * hidden_size`` features, which are split along dim 1 into a
    (shift, scale, gate) triple for the attention branch and another triple
    for the MLP branch. Each branch is added back to ``x`` as a gated residual.

    Args:
        hidden_size: Feature width used by the norms, attention and MLP.
        num_heads: Number of heads passed to the attention module.
        mlp_ratio: The MLP hidden width is ``int(hidden_size * mlp_ratio)``.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()

        def make_activation():
            # Factory handed to Mlp as its activation constructor.
            return nn.GELU(approximate="tanh")

        # Submodules are created in a fixed order (norm1, attn, norm2, mlp,
        # adaLN_modulation) so parameter initialisation order stays the same.
        # The LayerNorms carry no learnable affine parameters.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=make_activation,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Apply the gated attention and MLP residual branches to ``x``.

        NOTE(review): presumably ``x`` is (batch, tokens, hidden_size) and
        ``c`` is (batch, hidden_size); confirm against callers.
        """
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(c).chunk(6, dim=1)

        # Attention branch: normalise, modulate, attend, then gate the residual.
        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + self.attn(attn_in) * gate_msa.unsqueeze(1)

        # MLP branch, same pattern with its own shift/scale/gate.
        mlp_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + self.mlp(mlp_in) * gate_mlp.unsqueeze(1)
        return x


class FinalLayer(nn.Module):
    """Output head: modulated LayerNorm followed by a linear projection.

    A SiLU followed by a linear layer maps the conditioning tensor ``c`` to
    ``2 * hidden_size`` features, split along dim 1 into a shift and a scale.
    The normalised input is modulated with them and then projected to
    ``patch_size * patch_size * out_channels`` features.

    Args:
        hidden_size: Feature width of the incoming tensor.
        patch_size: Side length used to size the projection output.
        out_channels: Channel count used to size the projection output.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        projection_width = patch_size * patch_size * out_channels
        # Creation order (norm_final, linear, adaLN_modulation) is kept fixed
        # so parameter initialisation order and state_dict keys are unchanged.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, projection_width, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Return the linear projection of the modulated, normalised ``x``."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        normalised = self.norm_final(x)
        return self.linear(modulate(normalised, shift, scale))

    

In [9]:
import torch

# Hyper-parameters for the smoke test.
hidden_size, num_heads, mlp_ratio = 256, 8, 4.0
patch_size, out_channels = 16, 3

# Build one DitBlock, then the FinalLayer output head.
dit_block = DitBlock(hidden_size=hidden_size, num_heads=num_heads, mlp_ratio=mlp_ratio)
final_layer = FinalLayer(
    hidden_size=hidden_size,
    patch_size=patch_size,
    out_channels=out_channels,
)

# Random test inputs:
#   x: batch_size=1, sequence_length=10, hidden_size=256
#   c: batch_size=1, hidden_size=256
x = torch.randn(1, 10, hidden_size)
c = torch.randn(1, hidden_size)

# Run the block and report the output shape.
output_dit_block = dit_block(x, c)
print("DitBlock Output Shape:", output_dit_block.shape)

# Feed the block output through the final layer and report its shape.
output_final_layer = final_layer(output_dit_block, c)
print("FinalLayer Output Shape:", output_final_layer.shape)

DitBlock Output Shape: torch.Size([1, 10, 256])
FinalLayer Output Shape: torch.Size([1, 10, 768])
