# In [None]:  (notebook cell marker left over from export)

# multimodal_stimulus_fmri_predict/classifiers/multimodal_classifier.py
import torch
import torch.nn as nn
from typing import Dict, Tuple
from ..core.base_classifier import BaseClassifier

class MultiModalClassifier(BaseClassifier):
    """Multi-modal classifier combining image and fMRI features.

    Hyper-parameters are read from ``self.config`` via ``.get``; the
    attribute itself is expected to be provided by ``BaseClassifier``.
    """

    def build_model(self) -> nn.Module:
        """Build the multi-modal fusion model from ``self.config``.

        Keys missing from the config fall back to ``image_backbone='resnet50'``,
        ``fmri_input_dim=1000``, ``fusion_dim=256`` and ``num_classes=2``.

        Returns:
            A ``MultiModalFusionModel`` instance.
        """
        return MultiModalFusionModel(
            image_backbone=self.config.get('image_backbone', 'resnet50'),
            fmri_input_dim=self.config.get('fmri_input_dim', 1000),
            fusion_dim=self.config.get('fusion_dim', 256),
            num_classes=self.config.get('num_classes', 2)
        )

    def preprocess_data(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Preprocess both image and fMRI data.

        Args:
            data: ``(image_data, fmri_data)``. ``image_data`` is indexed as
                a 4-D tensor with channels on dim 1; ``fmri_data`` is
                normalized along dim 1 (assumed to be the feature axis,
                i.e. ``(batch, features)`` -- confirm against the caller).

        Returns:
            ``(image_data, fmri_data)`` after normalization.
        """
        image_data, fmri_data = data

        # Single-channel images are tiled to three channels so that the
        # 3-channel normalization below applies.
        if image_data.shape[1] == 1:
            image_data = image_data.repeat(1, 3, 1, 1)

        # Per-channel normalization with the standard ImageNet statistics,
        # created directly on the image's device to avoid a host->device copy.
        mean = torch.tensor([0.485, 0.456, 0.406], device=image_data.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=image_data.device).view(1, 3, 1, 1)
        image_data = (image_data - mean) / std

        # Z-score each fMRI sample along dim 1. A constant row has zero
        # standard deviation, which would give 0/0 = NaN; clamping the
        # denominator to machine epsilon maps such rows to zeros and leaves
        # every row with a larger std numerically unchanged.
        fmri_mean = fmri_data.mean(dim=1, keepdim=True)
        fmri_std = fmri_data.std(dim=1, keepdim=True)
        fmri_std = fmri_std.clamp_min(torch.finfo(fmri_std.dtype).eps)
        fmri_data = (fmri_data - fmri_mean) / fmri_std

        return image_data, fmri_data


class MultiModalFusionModel(nn.Module):
    """Multi-modal fusion model architecture.

    An image backbone and an MLP fMRI encoder each produce a feature vector;
    the two vectors are concatenated along dim 1 and passed through an MLP
    head that outputs ``num_classes`` values per sample.
    """

    def __init__(self, image_backbone: str, fmri_input_dim: int,
                 fusion_dim: int, num_classes: int):
        """Build the image encoder, the fMRI encoder and the fusion head.

        Args:
            image_backbone: ``'resnet50'`` or ``'efficientnet_b0'``.
            fmri_input_dim: Size of the fMRI input vector fed to the first
                linear layer of the fMRI encoder.
            fusion_dim: Width of the first hidden layer of the fusion head
                (the second hidden layer is ``fusion_dim // 2`` wide).
            num_classes: Number of outputs of the final linear layer.

        Raises:
            ValueError: If ``image_backbone`` is not a supported name.
        """
        super().__init__()

        # Validate before importing torchvision so that a bad name is
        # reported as ValueError regardless of the environment.
        if image_backbone not in ('resnet50', 'efficientnet_b0'):
            raise ValueError(f"Unsupported backbone: {image_backbone}")

        # The original code referenced `models` without importing it, which
        # raised NameError on construction. Imported locally so this class
        # works on its own.
        from torchvision import models

        # Image encoder: the classification head is replaced by Identity so
        # the backbone returns its pooled feature vector.
        # NOTE(review): recent torchvision releases deprecate `pretrained=`
        # in favour of `weights=` -- confirm against the installed version.
        if image_backbone == 'resnet50':
            self.image_encoder = models.resnet50(pretrained=True)
            self.image_encoder.fc = nn.Identity()
            image_feat_dim = 2048
        else:  # 'efficientnet_b0'
            self.image_encoder = models.efficientnet_b0(pretrained=True)
            self.image_encoder.classifier = nn.Identity()
            image_feat_dim = 1280

        # fMRI encoder: fmri_input_dim -> 512 -> 256 with ReLU and dropout.
        fmri_feat_dim = 256
        self.fmri_encoder = nn.Sequential(
            nn.Linear(fmri_input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, fmri_feat_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Fusion head operating on the concatenated image + fMRI features.
        self.fusion = nn.Sequential(
            nn.Linear(image_feat_dim + fmri_feat_dim, fusion_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(fusion_dim, fusion_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(fusion_dim // 2, num_classes)
        )

    def forward(self, data: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Run both encoders and the fusion head.

        Args:
            data: ``(image_data, fmri_data)`` pair; each element is passed to
                its own encoder.

        Returns:
            Output of the fusion head for the concatenated features.
        """
        image_data, fmri_data = data

        # Extract per-modality features
        image_features = self.image_encoder(image_data)
        fmri_features = self.fmri_encoder(fmri_data)

        # Concatenate along the feature axis and classify
        combined_features = torch.cat([image_features, fmri_features], dim=1)
        return self.fusion(combined_features)

