In [1]:
import import_ipynb
import torch
import torch.nn as nn
from moe_encoders import ImageMoE, TextMoE 
from cross_attention import CrossAttention
from TextDecoder import TextDecoder

ModuleNotFoundError: No module named 'moe_encoders'

In [None]:
class DualTowerModel(nn.Module):
    """Two-tower image/text model: one MoE encoder per modality plus a fusion MLP.

    The constructor builds an ``ImageMoE`` image tower, a ``TextMoE`` text
    tower, an MLP that maps ``2 * output_dim`` features down to ``output_dim``,
    and a linear classification head on top of the fused features.

    Args:
        vocab_size: Forwarded to ``TextMoE`` as its first positional argument.
        output_dim: Width shared by both towers (``embed_dim``), the fusion MLP
            and the classifier input. Defaults to 1024.
        n_head: Accepted for interface compatibility; not used by this
            constructor. Presumably meant for the cross-attention layers --
            TODO confirm.
        num_classes: Number of output features of ``self.classifier``.
        max_text_length: Forwarded to ``TextMoE`` as ``seq_length``.

    NOTE(review): the ``forward`` function defined later in this notebook also
    reads ``self.img2text_attention``, ``self.text2img_attention`` and
    ``self.text_decoder``. They are not created here because the constructor
    signatures of ``CrossAttention`` and ``TextDecoder`` are not visible in
    this notebook; they must be added before ``forward`` can run.
    """

    def __init__(self, vocab_size, output_dim=1024, n_head=8, num_classes=10, max_text_length=16):
        super().__init__()

        # Image backbone with increased capacity: more experts in total and
        # more experts routed per token. Image geometry is fixed here
        # (img_size=32, patch_size=4, in_channels=3).
        self.image_tower = ImageMoE(
            img_size=32,
            patch_size=4,
            in_channels=3,
            # Tied to output_dim (was hard-coded to 1024) so the towers stay
            # consistent with the fusion MLP below for non-default widths.
            embed_dim=output_dim,
            num_experts=16,
            top_k=4,
        )

        # Text encoder with the same expert configuration as the image tower.
        self.text_tower = TextMoE(
            vocab_size,
            # Tied to max_text_length (was hard-coded to 16).
            seq_length=max_text_length,
            embed_dim=output_dim,
            num_experts=16,
            top_k=4,
        )

        # Cross-modal fusion MLP: consumes two concatenated output_dim-wide
        # feature tensors and projects them back down to output_dim.
        self.fusion_layer = nn.Sequential(
            nn.Linear(output_dim * 2, output_dim * 2),
            nn.LayerNorm(output_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(output_dim * 2, output_dim),
            nn.LayerNorm(output_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

        # Classification head applied to the fusion output; ``forward`` calls
        # ``self.classifier`` but the original constructor never created it.
        self.classifier = nn.Linear(output_dim, num_classes)

In [None]:
def forward(self, images, input_ids, attention_mask):
        """Encode both modalities, cross-attend, fuse, then decode and classify.

        Reads these attributes from ``self``: ``image_tower``, ``text_tower``,
        ``img2text_attention``, ``text2img_attention``, ``fusion_layer``,
        ``text_decoder`` and ``classifier``.

        NOTE(review): this function sits at module level in its own notebook
        cell; confirm it is attached to the model class and that every
        attribute listed above is created by that class's constructor.

        Args:
            images: Passed unchanged to ``self.image_tower``.
            input_ids: Passed to ``self.text_tower`` and later to
                ``self.text_decoder``.
            attention_mask: Passed to ``self.text_tower`` with ``input_ids``.

        Returns:
            An 8-tuple, in order:
            the third value returned by the image tower, the third value
            returned by the text tower, the fourth value returned by the image
            tower, the fourth value returned by the text tower, the output of
            ``self.classifier`` on the fused features, the output of
            ``self.text_decoder``, the pair unpacked from the image tower's
            fifth value, and the pair unpacked from its sixth value.
        """
        # Run the towers. The image tower must yield six values (the last two
        # being pairs) and the text tower four; the leading two values of each
        # are not used by this function.
        (
            _,
            _,
            img_vec,
            img_cls_out,
            (experts_a, experts_b),
            (gates_a, gates_b),
        ) = self.image_tower(images)
        _, _, txt_vec, txt_cls_out = self.text_tower(input_ids, attention_mask)

        # Cross-modal attention in both directions; note the swapped argument order.
        attended_text = self.img2text_attention(txt_vec, img_vec)
        attended_image = self.text2img_attention(img_vec, txt_vec)

        # Concatenate along the last dimension and fuse.
        joint = torch.cat((attended_text, attended_image), dim=-1)
        fused = self.fusion_layer(joint)

        # Text reconstruction first, then classification, both from the fused features.
        decoded_text = self.text_decoder(fused, input_ids)
        fused_cls_out = self.classifier(fused)

        expert_pair = (experts_a, experts_b)
        gating_pair = (gates_a, gates_b)
        return (
            img_vec,
            txt_vec,
            img_cls_out,
            txt_cls_out,
            fused_cls_out,
            decoded_text,
            expert_pair,
            gating_pair,
        )