In [None]:
import torch
import torch.nn as nn
import math
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
from torchvision.ops import DeformConv2d as TVDeformConv2d
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10


# --------------------------- 辅助模块：位置编码 ---------------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of sequences.

    A table of shape [1, max_len, d_model] is precomputed and stored in the
    non-trainable buffer ``pe``. Even feature indices hold sines and odd
    indices hold cosines of ``position * div_term``.

    Args:
        d_model: Feature dimension of the sequences (odd values are supported).
        max_len: Number of positions precomputed at construction time.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.register_buffer('pe', self._build_table(max_len, d_model))

    @staticmethod
    def _build_table(length, d_model):
        """Return a [1, length, d_model] sinusoidal table."""
        position = torch.arange(length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        table = torch.zeros(1, length, d_model)
        table[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column,
        # so the surplus frequency is dropped; for even d_model this slice
        # keeps every column.
        table[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        return table

    def forward(self, x):
        """Return ``x`` with position encodings added.

        Args:
            x: Tensor of shape [B, L, D] where D equals ``d_model``.

        Returns:
            Tensor of shape [B, L, D].
        """
        B, L, D = x.shape
        assert D == self.pe.size(2), "特征维度不匹配"
        if L <= self.pe.size(1):
            pe = self.pe[:, :L, :]
        else:
            # The sequence is longer than the precomputed table: build the
            # encodings for this call instead of failing on a shape mismatch.
            # The buffer itself is left untouched so state_dict shapes stay stable.
            pe = self._build_table(L, D).to(device=self.pe.device, dtype=self.pe.dtype)
        return x + pe.expand(B, -1, -1)


# --------------------------- 可变形卷积模块 ---------------------------
class DeformConv2d(nn.Module):
    """Deformable convolution whose sampling offsets are predicted from the input.

    A regular convolution (``offset_conv``) produces ``2 * kernel_size**2``
    offset channels, which are fed together with the input into
    ``torchvision.ops.DeformConv2d`` (``dcn``). The offset convolution starts
    with zero weight and bias, so the initial offsets are all zero.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of the deformable convolution.
        kernel_size: Square kernel size shared by both convolutions.
        padding: Zero padding shared by both convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.offset_channels = 2 * kernel_size * kernel_size
        # The offset map must have the same spatial size as the deformable
        # convolution's output, so the offset predictor uses the same kernel
        # size and padding as ``dcn`` (it was previously fixed to 3, which only
        # matched for the default kernel size).
        self.offset_conv = nn.Conv2d(
            in_channels,
            self.offset_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=True,
        )
        self.offset_conv.weight.data.zero_()
        self.offset_conv.bias.data.zero_()

        self.dcn = TVDeformConv2d(
            in_channels, out_channels, kernel_size, padding=padding
        )

    def forward(self, x):
        """Predict offsets from ``x`` and apply the deformable convolution to ``x``."""
        offset = self.offset_conv(x)  # [B, 2 * k * k, H_out, W_out]
        return self.dcn(x, offset)


# --------------------------- Transformer 单元 ---------------------------
class TransformerUnit(nn.Module):
    """Two-head self-attention followed by a feed-forward layer, both residual.

    Accepts either a 4-D feature map [B, C, H, W], which is flattened into
    H*W tokens of width C and restored to its original shape on return, or an
    input that is already a token sequence and is passed through unchanged in
    layout. Position encodings are added before attention. Layer
    normalisation is applied only to the input of the feed-forward branch.

    Args:
        d_model: Token width; for 4-D inputs this is the channel count.
    """

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = d_model * 4
        self.attention = nn.MultiheadAttention(d_model, num_heads=2, batch_first=True)
        self.pos_enc = PositionalEncoding(d_model)
        self.norm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, d_model),
        )

    def forward(self, x):
        """Run the unit on ``x`` and return a tensor with the same shape as ``x``."""
        # Remember the map shape only when the input is an image-like tensor.
        map_shape = x.shape if x.dim() == 4 else None

        tokens = x
        if map_shape is not None:
            tokens = x.flatten(2).transpose(1, 2)  # [B, H*W, C]

        tokens = self.pos_enc(tokens)

        attended, _ = self.attention(tokens, tokens, tokens)
        tokens = tokens + attended
        tokens = tokens + self.ffn(self.norm(tokens))

        if map_shape is None:
            return tokens
        # Back to channels-first, then split the token axis into H and W.
        return tokens.transpose(1, 2).view(map_shape)


# --------------------------- Hierarchical fusion modules ---------------------------

# --------------------------- Added module: adaptive fusion of local and global features ---------------------------
class AdaptiveFusion(nn.Module):
    """Fuse a local and a global feature map with learned scalar weights.

    The local map goes through channel attention and the global map through
    spatial attention. The two results are combined as a weighted sum, with
    each weight being the sigmoid of a learnable scalar initialised to 0.5,
    and the sum is passed through a depthwise 3x3 convolution.

    Args:
        channels: Channel count of both inputs and of the output.
    """

    def __init__(self, channels):
        super().__init__()
        self.channel_att = ChannelAttention(channels)
        self.spatial_att = SpatialAttention()
        self.global_weight = nn.Parameter(torch.tensor(0.5))
        self.local_weight = nn.Parameter(torch.tensor(0.5))
        self.dw_conv = nn.Conv2d(channels, channels, 3,
                                 padding=1, groups=channels)

    def forward(self, local, global_feat):
        """Return the fused feature map.

        Args:
            local: Local-branch features.
            global_feat: Global-branch features with the same shape as ``local``.
        """
        # ChannelAttention and SpatialAttention already return their input
        # multiplied by the attention gate, so the result is used directly.
        # Multiplying by the input a second time (as was done before) squared
        # the features instead of re-weighting them.
        local_att = self.channel_att(local)
        global_att = self.spatial_att(global_feat)
        # Sigmoid keeps each mixing weight in (0, 1).
        w_g = torch.sigmoid(self.global_weight)
        w_l = torch.sigmoid(self.local_weight)
        return self.dw_conv(w_g * global_att + w_l * local_att)

# --------------------------- 修改层级融合模块 ---------------------------
class HierarchicalFusion(nn.Module):
    """Run a local and a global branch on the same input and fuse the results.

    The local branch is a 3x3 deformable convolution followed by batch
    normalisation and ReLU. The global branch is a 1x1 convolution, batch
    normalisation and ReLU followed by a TransformerUnit. Both branches emit
    ``d_model`` channels, which AdaptiveFusion merges.

    Args:
        in_channels: Channel count of the input feature map.
        d_model: Channel count produced by each branch and by the fusion.
    """

    def __init__(self, in_channels, d_model):
        super().__init__()
        local_stages = [
            DeformConv2d(in_channels, d_model, 3, padding=1),
            nn.BatchNorm2d(d_model),
            nn.ReLU(inplace=True),
        ]
        self.local_branch = nn.Sequential(*local_stages)

        global_stages = [
            nn.Conv2d(in_channels, d_model, 1),
            nn.BatchNorm2d(d_model),
            nn.ReLU(inplace=True),
            TransformerUnit(d_model),
        ]
        self.global_branch = nn.Sequential(*global_stages)

        self.adaptive_fusion = AdaptiveFusion(d_model)

    def forward(self, x):
        """Return the fusion of the local-branch and global-branch outputs for ``x``."""
        return self.adaptive_fusion(self.local_branch(x), self.global_branch(x))

# --------------------------- 新增注意力子模块 ---------------------------
class ChannelAttention(nn.Module):
    """Re-weight the channels of a feature map with a learned per-channel gate.

    Global average- and max-pooled descriptors are summed and passed through
    a two-layer bottleneck ending in a sigmoid. The bottleneck width is
    ``channels // reduction`` but never less than 8.

    Args:
        channels: Channel count of the input feature map.
        reduction: Divisor used to derive the bottleneck width.
    """

    def __init__(self, channels, reduction=4):
        super().__init__()
        bottleneck = max(8, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by its channel gate (same shape as ``x``)."""
        batch, chans = x.shape[0], x.shape[1]
        pooled_avg = self.avg_pool(x).view(batch, chans)
        pooled_max = self.max_pool(x).view(batch, chans)
        gate = self.fc(pooled_avg + pooled_max)
        # Broadcast the [B, C] gate over the spatial dimensions.
        return x * gate.view(batch, chans, 1, 1)

class SpatialAttention(nn.Module):
    """Re-weight the spatial positions of a feature map with a learned gate.

    The per-position channel mean and channel max are stacked into a
    two-channel map, which a bias-free 3x3 convolution plus sigmoid turns
    into a single-channel gate.
    """

    def __init__(self):
        super().__init__()
        gate_layers = [
            nn.Conv2d(2, 1, 3, padding=1, bias=False),
            nn.Sigmoid(),
        ]
        self.conv = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` multiplied by its spatial gate (same shape as ``x``)."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat((channel_mean, channel_max), dim=1)
        return x * self.conv(descriptor)



# --------------------------- AMFT 主干模型 ---------------------------
class AMFT(nn.Module):
    """Classifier that fuses features tapped from three MobileNetV2 stages.

    The backbone is the first ten feature stages (indices 0-9) of
    torchvision's ImageNet-pretrained MobileNetV2. After each stage listed in
    ``FUSION_INDICES`` the current feature map is passed through a
    HierarchicalFusion module. The fused maps are resized to the resolution of
    the last one, combined with learned softmax weights and classified.

    Args:
        num_classes: Number of output logits.
        alpha: Accepted for backward compatibility but not applied. The
            pretrained weights have fixed widths, and the previous "scaling"
            only overwrote the ``in_channels``/``out_channels`` attributes of
            the pretrained convolutions without resizing their weights.
        d_model: Channel count produced by every fusion module.
    """

    # Backbone stages whose outputs are fused; shared by __init__ and forward.
    FUSION_INDICES = (3, 6, 9)

    def __init__(self, num_classes=7, alpha=0.75, d_model=128):
        super().__init__()

        self.backbone = self._build_mobilenet_backbone(alpha)
        fusion_channels = [
            self._get_output_channels(self.backbone[idx]) for idx in self.FUSION_INDICES
        ]

        self.fusion_layers = nn.ModuleList([
            HierarchicalFusion(c, d_model) for c in fusion_channels
        ])

        # Maps one pooled scalar per fused map to one mixing logit per map.
        self.fusion_linear = nn.Linear(len(self.FUSION_INDICES), len(self.FUSION_INDICES))

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(d_model, num_classes)
        )

    def _build_mobilenet_backbone(self, alpha):
        """Return stages 0-9 of the ImageNet-pretrained MobileNetV2 feature extractor.

        ``alpha`` is intentionally unused. Overwriting ``in_channels`` /
        ``out_channels`` on pretrained convolutions does not resize their
        weight tensors, so it left the computation unchanged while making the
        reported channel counts disagree with the real ones.
        """
        features = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1).features
        return nn.Sequential(*list(features.children())[:10])

    @staticmethod
    def _get_output_channels(layer):
        """Return the number of channels ``layer`` outputs.

        Uses ``layer.out_channels`` when present. Otherwise falls back to the
        last ``nn.Conv2d`` registered inside ``layer.conv``; trailing
        normalisation or activation layers are skipped, on the assumption
        that they keep the channel count.

        Raises:
            ValueError: If neither source of information is available.
        """
        if hasattr(layer, 'out_channels'):
            return layer.out_channels
        conv = getattr(layer, 'conv', None)
        if isinstance(conv, nn.Module):
            for module in reversed(list(conv.modules())):
                if isinstance(module, nn.Conv2d):
                    return module.out_channels
        raise ValueError("无法确定层的输出通道数")

    def forward(self, x):
        """Return class logits of shape [B, num_classes] for the image batch ``x``."""
        fusion_outputs = []

        for i, layer in enumerate(self.backbone):
            x = layer(x)
            if i in self.FUSION_INDICES:
                # Fusion modules are stored in the order their stages are visited.
                fusion_layer = self.fusion_layers[len(fusion_outputs)]
                fusion_outputs.append(fusion_layer(x))

        final_feat = self._aggregate_features(fusion_outputs)
        return self.classifier(final_feat)

    def _aggregate_features(self, features):
        """Combine fused maps into one map with input-dependent softmax weights.

        Args:
            features: Non-empty list of N tensors [B, C, H_i, W_i] with equal B and C.

        Returns:
            Tensor [B, C, H, W] at the spatial size of the last entry.

        Raises:
            ValueError: If ``features`` is empty.
        """
        if not features:
            raise ValueError("features must contain at least one tensor")

        target_size = features[-1].shape[-2:]
        resized_features = [
            feat if feat.shape[-2:] == target_size
            else nn.functional.interpolate(
                feat, size=target_size, mode='bilinear', align_corners=False
            )
            for feat in features
        ]

        # One scalar per map and sample: channel mean followed by spatial mean.
        summary = torch.cat([f.mean(dim=1, keepdim=True) for f in resized_features], dim=1)
        summary = nn.functional.adaptive_avg_pool2d(summary, 1).flatten(1)  # [B, N]

        weights = torch.softmax(self.fusion_linear(summary), dim=1)  # [B, N]
        weights = weights.t()[:, :, None, None, None]  # [N, B, 1, 1, 1]

        stacked = torch.stack(resized_features)  # [N, B, C, H, W]
        return torch.sum(stacked * weights, dim=0)  # [B, C, H, W]