In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

class Expert(nn.Module):
    """Expert network: a two-layer feed-forward block used instead of a standard FFN.

    Args:
        dim: Size of the input and output feature dimension.
        hidden_dim: Width of the intermediate layer.
        dropout: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        # Linear -> GELU -> Dropout -> Linear -> Dropout. The order is kept
        # fixed so the parameters stay at ``net.0`` and ``net.3``.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the feed-forward stack and return the result."""
        transformed = self.net(x)
        return transformed

class MoEGatingNetwork(nn.Module):
    """Gating network that routes each token to its top-k experts.

    Args:
        dim: Feature dimension of the incoming tokens.
        num_experts: Number of experts to score.
        k: Number of experts selected per token.
    """

    def __init__(self, dim, num_experts, k=2):
        super().__init__()
        self.k = k
        self.gate = nn.Linear(dim, num_experts)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Score the experts for every token.

        Returns:
            A 4-tuple ``(weights, mask, topk_weights, topk_indices)``:
            the softmax routing probabilities over all experts, a 0/1 mask
            of the same shape marking the selected experts, the selected
            probabilities renormalised to sum to 1, and the selected expert
            indices.
        """
        # Routing probabilities over the expert axis (last dimension).
        probs = self.softmax(self.gate(x))

        # Keep the k most probable experts and renormalise their weights.
        best_probs, best_ids = probs.topk(self.k, dim=-1)
        renormalised = best_probs / best_probs.sum(dim=-1, keepdim=True)

        # Sparse selection mask: 1 where an expert was chosen, 0 elsewhere.
        selected = torch.zeros_like(probs).scatter(-1, best_ids, 1)

        return probs, selected, renormalised, best_ids

class MoEFFN(nn.Module):
    """Mixture-of-experts feed-forward layer used instead of a standard FFN.

    Each token is routed to its top-k experts by ``MoEGatingNetwork``; the
    layer output is the sum of the selected experts' outputs weighted by the
    renormalised gate weights.

    Args:
        dim: Token feature dimension (input and output).
        hidden_dim: Hidden width of every expert.
        num_experts: Number of experts.
        k: Number of experts each token is routed to.
        dropout: Dropout probability inside each expert.
        aux_loss_coef: Multiplier applied to the load-balancing loss.
    """

    def __init__(self, dim, hidden_dim, num_experts=4, k=2, dropout=0.1,
                 aux_loss_coef=0.01):
        super().__init__()
        self.dim = dim
        self.num_experts = num_experts
        self.k = k

        # Expert networks
        self.experts = nn.ModuleList([
            Expert(dim, hidden_dim, dropout) for _ in range(num_experts)
        ])

        # Gating network
        self.gating = MoEGatingNetwork(dim, num_experts, k)

        # Coefficient of the auxiliary (load-balancing) loss
        self.aux_loss_coef = aux_loss_coef

    def forward(self, x):
        """Route tokens through their selected experts.

        Args:
            x: Tensor of shape (batch, seq_len, dim).

        Returns:
            ``(output, aux_loss)`` where ``output`` has the shape of ``x``
            and ``aux_loss`` is the scaled load-balancing loss.
        """
        weights, mask, topk_weights, topk_indices = self.gating(x)

        output = torch.zeros_like(x)
        for expert_id, expert in enumerate(self.experts):
            # Tokens that picked this expert, plus the top-k slot holding it.
            # A token selects a given expert at most once, so every
            # (batch, seq) pair appears at most once here.
            batch_idx, seq_idx, slot_idx = torch.where(topk_indices == expert_id)
            if batch_idx.numel() == 0:
                continue

            # Run the expert only on its own tokens.
            expert_out = expert(x[batch_idx, seq_idx])
            gate = topk_weights[batch_idx, seq_idx, slot_idx].unsqueeze(-1)

            # Add the weighted contribution straight into the output. This
            # replaces a dense (batch, seq, experts, dim) buffer and a
            # per-token Python loop with one indexed accumulation per expert.
            output.index_put_(
                (batch_idx, seq_idx),
                (gate * expert_out).to(output.dtype),
                accumulate=True,
            )

        # Auxiliary (load-balancing) loss
        aux_loss = self._load_balancing_loss(weights, mask)

        return output, aux_loss

    def _load_balancing_loss(self, weights, mask):
        """Compute the scaled load-balancing loss.

        Args:
            weights: Routing probabilities, (batch, seq_len, num_experts).
            mask: 0/1 expert-selection mask with the same shape.

        Returns:
            ``num_experts * sum(usage * mean_prob) * aux_loss_coef``.
        """
        # Fraction of tokens dispatched to each expert -> (num_experts,)
        expert_usage = mask.float().mean(0).mean(0)

        # Mean routing probability assigned to each expert -> (num_experts,)
        weight_sum = weights.mean(0).mean(0)

        aux_loss = torch.sum(expert_usage * weight_sum) * self.num_experts
        return aux_loss * self.aux_loss_coef

class Attention(nn.Module):
    """Multi-head self-attention.

    Args:
        dim: Feature dimension of the input tokens.
        heads: Number of attention heads.
        dim_head: Feature dimension of each head.
        dropout: Dropout probability applied to the attention weights and
            after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.1):
        super().__init__()
        inner_dim = heads * dim_head
        # The output projection is only skipped for a single head whose
        # width already equals the model width.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # One bias-free projection produces queries, keys and values together.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, tokens, dim)."""
        # Split the fused projection into q, k, v and move heads to axis 1.
        queries, keys, values = (
            rearrange(part, 'b n (h d) -> b h n d', h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )

        # Scaled dot-product scores between every pair of tokens.
        scores = torch.matmul(queries, keys.transpose(-1, -2)) * self.scale
        probs = self.dropout(self.attend(scores))

        # Weighted sum of values, then merge the heads back together.
        context = rearrange(torch.matmul(probs, values), 'b h n d -> b n (h d)')
        return self.to_out(context)

class TransformerBlock(nn.Module):
    """Pre-norm Transformer block whose feed-forward part is an MoE layer.

    Args:
        dim: Token feature dimension.
        heads: Number of attention heads.
        dim_head: Feature dimension of each attention head.
        mlp_dim: Hidden width of every expert.
        num_experts: Number of experts in the MoE layer.
        k: Number of experts each token is routed to.
        dropout: Dropout probability used by attention and the experts.
    """

    def __init__(self, dim, heads, dim_head, mlp_dim, num_experts=4, k=2, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.moe_ffn = MoEFFN(dim, mlp_dim, num_experts=num_experts, k=k, dropout=dropout)

    def forward(self, x):
        """Return ``(x, aux_loss)`` after the attention and MoE sub-layers."""
        # Self-attention sub-layer with residual connection.
        attended = self.attn(self.norm1(x))
        x = x + attended

        # MoE feed-forward sub-layer with residual connection; the MoE layer
        # also hands back its auxiliary loss.
        moe_out, balance_loss = self.moe_ffn(self.norm2(x))
        x = x + moe_out

        return x, balance_loss

class ViTMoEForRestoration(nn.Module):
    """Vision Transformer with MoE feed-forward layers for image restoration.

    A strided convolution turns the image into patch tokens, ``depth``
    ``TransformerBlock`` layers process the tokens, and a transposed
    convolution maps the token grid back to an image of the input size.
    The output passes through ``Tanh`` and therefore lies in [-1, 1].

    Args:
        image_size: Input size, an int (square) or a ``(height, width)`` tuple.
        patch_size: Patch size, an int (square) or a ``(height, width)`` tuple.
        dim: Token feature dimension.
        depth: Number of Transformer blocks.
        heads: Number of attention heads per block.
        mlp_dim: Hidden width of every expert.
        num_experts: Number of experts per MoE layer.
        k: Number of experts each token is routed to.
        channels: Number of image channels (input and output).
        dim_head: Feature dimension of each attention head.
        dropout: Dropout probability inside the Transformer blocks.
        emb_dropout: Dropout probability applied to the embedded patches.
    """

    def __init__(self, image_size, patch_size, dim, depth, heads, mlp_dim,
                 num_experts=4, k=2, channels=3, dim_head=64, dropout=0.1, emb_dropout=0.1):
        super().__init__()
        image_height, image_width = image_size if isinstance(image_size, tuple) else (image_size, image_size)
        patch_height, patch_width = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        grid_height = image_height // patch_height
        grid_width = image_width // patch_width
        num_patches = grid_height * grid_width
        patch_dim = channels * patch_height * patch_width

        self.patch_size = patch_size
        self.num_patches = num_patches
        self.image_size = image_size
        self.channels = channels
        # Patch-grid and image sizes are stored explicitly so forward() can
        # rebuild the 2-D token layout for non-square images and patches.
        self.grid_size = (grid_height, grid_width)
        self._image_hw = (image_height, image_width)

        # Patch embedding: (B, C, H, W) -> (B, dim, grid_h, grid_w) -> (B, dim, num_patches)
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size),
            nn.Flatten(2)
        )

        # Learned positional embedding
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # Transformer layers
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(TransformerBlock(
                dim, heads, dim_head, mlp_dim, num_experts, k, dropout
            ))

        # Per-patch pixel head. forward() does not use it; it is kept so the
        # module's parameters and state-dict keys stay unchanged.
        self.to_pixel = nn.Sequential(
            nn.Linear(dim, patch_dim),
            nn.GELU()
        )

        # Transposed convolution that restores the input resolution
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(dim, channels, kernel_size=patch_size, stride=patch_size),
            nn.Tanh()  # limits the output range to [-1, 1]
        )

        # Auxiliary losses collected during the most recent forward pass
        self.aux_losses = []

    def forward(self, img):
        """Restore ``img`` and record the per-layer auxiliary losses.

        Args:
            img: Tensor of shape (B, channels, H, W) where (H, W) equals the
                configured image size.

        Returns:
            Tensor of shape (B, channels, H, W).

        Raises:
            ValueError: If the spatial size of ``img`` differs from the
                configured image size.
        """
        if tuple(img.shape[-2:]) != self._image_hw:
            raise ValueError(
                f"Expected input of spatial size {self._image_hw}, "
                f"got {tuple(img.shape[-2:])}."
            )

        # Patch embedding
        x = self.to_patch_embedding(img)  # (B, dim, num_patches)
        x = x.permute(0, 2, 1)  # (B, num_patches, dim)
        b = x.shape[0]

        # Add the positional embedding (out of place: x is a view)
        x = x + self.pos_embedding
        x = self.dropout(x)

        # Transformer layers; the auxiliary losses are reset on every call
        self.aux_losses = []
        for layer in self.layers:
            x, aux_loss = layer(x)
            self.aux_losses.append(aux_loss)

        # Rebuild the 2-D token grid and upsample it with the transposed conv
        grid_height, grid_width = self.grid_size
        x = x.permute(0, 2, 1)  # (B, dim, num_patches)
        x = x.reshape(b, -1, grid_height, grid_width)  # (B, dim, grid_h, grid_w)
        out = self.upsample(x)  # (B, C, H, W)

        return out

    def get_aux_loss(self):
        """Return the sum of the auxiliary losses from the last forward pass.

        When no losses are stored, a zero scalar on the same device and
        dtype as the model's positional embedding is returned.
        """
        if self.aux_losses:
            return sum(self.aux_losses)
        return self.pos_embedding.new_zeros(())

# Example usage
if __name__ == "__main__":
    # Seed the RNG so the demo is reproducible
    torch.manual_seed(42)

    # Hyperparameters
    image_size, patch_size, channels = 64, 8, 3
    dim, depth, heads, mlp_dim = 256, 6, 8, 512
    num_experts, k = 4, 2

    # Build the ViT-MoE image restoration model
    model = ViTMoEForRestoration(
        image_size, patch_size, dim, depth, heads, mlp_dim,
        num_experts=num_experts, k=k, channels=channels,
    )

    # Example input batch
    x = torch.randn(2, channels, image_size, image_size)

    # Forward pass; the auxiliary loss is read back from the model afterwards
    output = model(x)
    aux_loss = model.get_aux_loss()

    print(f"输入形状: {x.shape}")
    print(f"输出形状: {output.shape}")
    print(f"辅助损失: {aux_loss.item():.4f}")

    # Simulated training step with a mean-squared-error objective
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Placeholder restoration target with the same shape as the input
    target = torch.randn_like(x)

    # Total loss = main reconstruction loss + auxiliary loss
    main_loss = criterion(output, target)
    total_loss = main_loss + aux_loss

    # Backward pass and parameter update
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    print(f"主损失: {main_loss.item():.4f}")
    print(f"总损失: {total_loss.item():.4f}")

    # Report the number of model parameters
    total_params = sum(map(torch.numel, model.parameters()))
    print(f"总参数量: {total_params:,}")

KeyboardInterrupt: 