# In [None]:

# multimodal_stimulus_fmri_predict/classifiers/efficientnet.py
import torch
import torch.nn as nn
from torchvision import models
from ..core.base_classifier import BaseClassifier

class EfficientNetClassifier(BaseClassifier):
    """EfficientNet classifier for image-based fMRI prediction.

    Configuration keys read from ``self.config`` (defaults in parentheses):
        model_name ('efficientnet_b0'): one of ``_SUPPORTED_VARIANTS``.
        pretrained (True): forwarded to the torchvision model builder.
        num_classes (2): output width of the final Linear layer.
        head_dropout (0.3): dropout probability used twice in the custom head.
        head_hidden_dim (512): width of the hidden layer in the custom head.
    """

    # torchvision EfficientNet builders this classifier knows how to wrap.
    _SUPPORTED_VARIANTS = (
        'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3',
        'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7',
    )

    # ImageNet per-channel normalization statistics (R, G, B).
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def build_model(self) -> nn.Module:
        """Build an EfficientNet backbone with a custom classification head.

        Returns:
            The torchvision EfficientNet whose ``classifier`` has been replaced
            by Dropout -> Linear(hidden) -> ReLU -> Dropout -> Linear(num_classes).

        Raises:
            ValueError: if ``model_name`` is not in ``_SUPPORTED_VARIANTS``.
        """
        model_name = self.config.get('model_name', 'efficientnet_b0')
        pretrained = self.config.get('pretrained', True)
        num_classes = self.config.get('num_classes', 2)
        dropout = self.config.get('head_dropout', 0.3)
        hidden_dim = self.config.get('head_hidden_dim', 512)

        if model_name not in self._SUPPORTED_VARIANTS:
            raise ValueError(f"Unsupported EfficientNet variant: {model_name}")

        # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
        # `weights=` (it still works, with a warning). Kept for compatibility with
        # older torchvision; migrate once the minimum version is pinned.
        model = getattr(models, model_name)(pretrained=pretrained)

        # Read the backbone's feature width from the stock head's Linear layer
        # (index 1 of torchvision's EfficientNet classifier) before replacing it.
        num_features = model.classifier[1].in_features
        model.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

        return model

    def preprocess_data(self, data: torch.Tensor) -> torch.Tensor:
        """Preprocess images for EfficientNet.

        Single-channel (grayscale) input is replicated to three channels, and
        every channel is then normalized with the ImageNet mean/std.

        Args:
            data: Image tensor whose last three dimensions are ``(C, H, W)``;
                any leading batch dimensions are optional. NOTE(review): values
                are assumed to already be scaled to [0, 1] — confirm with the
                data loader, since no rescaling happens here.

        Returns:
            The normalized tensor with three channels. Floating-point inputs
            keep their dtype (so fp16/bf16 models receive matching input);
            integer inputs are promoted to float32.

        Raises:
            ValueError: if ``data`` has fewer than three dimensions.
        """
        if data.dim() < 3:
            raise ValueError(
                "Expected an image tensor with at least 3 dims (..., C, H, W), "
                f"got shape {tuple(data.shape)}"
            )

        # Ensure RGB format. The channel axis is the third from the end, which
        # covers both batched (N, C, H, W) and unbatched (C, H, W) input.
        if data.shape[-3] == 1:
            leading = [1] * (data.dim() - 3)
            data = data.repeat(*leading, 3, 1, 1)

        # Build the statistics in the input's floating dtype so half-precision
        # input is not silently upcast to float32 by type promotion.
        dtype = data.dtype if data.is_floating_point() else torch.float32
        mean = torch.tensor(self._IMAGENET_MEAN, dtype=dtype, device=data.device).view(3, 1, 1)
        std = torch.tensor(self._IMAGENET_STD, dtype=dtype, device=data.device).view(3, 1, 1)

        return (data - mean) / std
