In [2]:
import torch
import torch.nn as nn

# Definition of the CategoricalEmbedder module
class CategoricalEmbedder(nn.Module):
    """
    Project a categorical condition vector (e.g. a data-source indicator)
    into a dense vector of size ``hidden_size``.

    The module is a single linear layer: it contains no dropout and injects
    no noise, so it computes the same function in training and in inference.
    """

    def __init__(self, input_size, hidden_size):
        """
        Args:
            input_size: Length of the incoming label vector (last dimension).
            hidden_size: Length of the embedding that is produced.
        """
        super().__init__()
        # One affine map from the label vector to the hidden dimension.
        self.fc = nn.Linear(input_size, hidden_size)

    def forward(self, labels):
        """
        Embed ``labels``.

        Args:
            labels: Tensor whose last dimension is ``input_size``, for
                example ``[0, 0, 1]``. Any dtype accepted by ``.float()``.

        Returns:
            Tensor with the same leading dimensions as ``labels`` and a last
            dimension of ``hidden_size``.
        """
        # nn.Linear holds float parameters, so cast integer labels first.
        as_float = labels.float()
        return self.fc(as_float)

# Definition of the AdaLN (adaptive layer normalization) module
class AdaLN(nn.Module):
    """
    Adaptive layer normalization.

    The input is layer-normalized over its last dimension, then scaled and
    shifted by ``gamma(c)`` and ``beta(c)``, two linear projections of a
    conditioning vector ``c``.

    The normalization uses the biased (population) variance with ``eps``
    added inside the square root, matching the standard layer-norm formula.
    The default unbiased ``Tensor.std`` would return NaN whenever the
    normalized dimension has a single element.
    """

    def __init__(self, cond_dim, hidden_dim, eps=1e-6):
        """
        Args:
            cond_dim: Size of the conditioning vector's last dimension.
            hidden_dim: Size of the feature dimension being modulated.
            eps: Small constant added to the variance for numerical
                stability.
        """
        super(AdaLN, self).__init__()
        self.eps = eps
        self.gamma_net = nn.Linear(cond_dim, hidden_dim)
        self.beta_net = nn.Linear(cond_dim, hidden_dim)

    def forward(self, x, c):
        """
        Args:
            x: Features, expected as (batch_size, seq_length, hidden_dim).
            c: Conditioning vectors, expected as (batch_size, cond_dim).

        Returns:
            Tensor shaped like ``x`` holding ``gamma(c) * norm(x) + beta(c)``.
        """
        # Per-sample scale and shift; the inserted axis broadcasts them over
        # the sequence dimension: (batch_size, 1, hidden_dim).
        gamma = self.gamma_net(c).unsqueeze(1)
        beta = self.beta_net(c).unsqueeze(1)

        # Layer normalization over the feature dimension. unbiased=False
        # keeps the variance finite (zero) when that dimension has size 1.
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)

        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        return gamma * x_normalized + beta

# --- Smoke test: CategoricalEmbedder ---
print("Testing CategoricalEmbedder")

hidden_size = 4  # dimensionality of the embedding vector
categorical_embedder = CategoricalEmbedder(input_size=3, hidden_size=hidden_size)

# A single label vector of length 3 (matches input_size above).
labels = torch.tensor([0, 0, 1])
embeddings = categorical_embedder(labels)  # embedding of the label vector
print(f"Input labels: {labels}")
print(f"Embeddings: {embeddings}\n")

# --- Smoke test: AdaLN ---
print("Testing AdaLN")
cond_dim = hidden_size  # dimensionality of the conditioning vector
hidden_dim = 4  # dimensionality of the input features

adaln = AdaLN(cond_dim, hidden_dim)

# Features x are laid out as (batch_size, seq_length, hidden_dim);
# there is one conditioning vector c per sample in the batch.
x = torch.randn(3, 5, hidden_dim)
c = torch.randn(3, cond_dim)

output = adaln(x, c)  # features after AdaLN modulation
print(f"Input features (x) shape: {x.shape}")
print(f"Condition vector (c) shape: {c.shape}")
print(f"Output shape after AdaLN: {output.shape}")

Testing CategoricalEmbedder
Input labels: tensor([0, 0, 1])
Embeddings: tensor([ 0.4218, -0.5789, -0.0714, -0.0017], grad_fn=<ViewBackward0>)

Testing AdaLN
Input features (x) shape: torch.Size([3, 5, 4])
Condition vector (c) shape: torch.Size([3, 4])
Output shape after AdaLN: torch.Size([3, 5, 4])
