In [None]:
# multimodal_stimulus_fmri_predict/classifiers/resnet.py
import torch
import torch.nn as nn
from torchvision import models
from ..core.base_classifier import BaseClassifier

class ResNetClassifier(BaseClassifier):
    """ResNet classifier for image-based fMRI prediction.

    Recognized ``self.config`` keys (all optional):

    * ``model_name`` -- one of ``resnet18``, ``resnet34``, ``resnet50``,
      ``resnet101``, ``resnet152`` (default ``resnet50``).
    * ``pretrained`` -- load ImageNet weights for the backbone (default ``True``).
    * ``num_classes`` -- output size of the classification head (default ``2``).
    * ``hidden_dim`` -- width of the head's hidden layer (default ``512``).
    * ``dropout`` -- dropout probability inside the head (default ``0.3``).
    """

    # Backbone names accepted in config; each is a torchvision.models builder.
    _VARIANTS = ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152')

    # Per-channel ImageNet statistics (RGB order); they assume inputs in [0, 1].
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def build_model(self) -> nn.Module:
        """Build a ResNet backbone with a custom two-layer classification head.

        Returns:
            The torchvision ResNet whose ``fc`` is replaced by
            ``Linear -> ReLU -> Dropout -> Linear(num_classes)``.

        Raises:
            ValueError: If ``model_name`` is not a supported ResNet variant.
        """
        model_name = self.config.get('model_name', 'resnet50')
        pretrained = self.config.get('pretrained', True)
        if model_name not in self._VARIANTS:
            raise ValueError(f"Unsupported ResNet variant: {model_name}")

        model = self._load_backbone(getattr(models, model_name), pretrained)

        # Replace final layer with a small MLP head.
        num_features = model.fc.in_features
        hidden_dim = self.config.get('hidden_dim', 512)
        model.fc = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.config.get('dropout', 0.3)),
            nn.Linear(hidden_dim, self.config.get('num_classes', 2)),
        )
        return model

    @staticmethod
    def _load_backbone(builder, pretrained) -> nn.Module:
        """Instantiate a torchvision ResNet, preferring the ``weights=`` API."""
        # ``pretrained=`` is deprecated since torchvision 0.13. ``IMAGENET1K_V1``
        # is the checkpoint ``pretrained=True`` maps to, so weights are unchanged.
        weights = 'IMAGENET1K_V1' if pretrained else None
        try:
            return builder(weights=weights)
        except TypeError:
            # torchvision < 0.13 has no ``weights`` keyword; use the legacy flag.
            return builder(pretrained=pretrained)

    def preprocess_data(self, data: torch.Tensor) -> torch.Tensor:
        """Prepare an image batch for the ResNet backbone.

        Grayscale input is broadcast to three channels, and every channel is
        normalized with the ImageNet mean/std. Those constants assume pixel
        values already scaled to [0, 1].

        Args:
            data: A ``(N, C, H, W)`` batch, or a single ``(C, H, W)`` image
                (a batch dimension is prepended), with ``C`` equal to 1 or 3.

        Returns:
            The normalized ``(N, 3, H, W)`` tensor. Integer inputs are promoted
            to floating point by the arithmetic, exactly as before.

        Raises:
            ValueError: If ``data`` is not 3-D/4-D, or its channel count is
                neither 1 nor 3.
        """
        if data.dim() == 3:
            data = data.unsqueeze(0)
        if data.dim() != 4:
            raise ValueError(
                "Expected a (N, C, H, W) or (C, H, W) tensor, "
                f"got shape {tuple(data.shape)}"
            )

        channels = data.shape[1]
        if channels == 1:
            # Ensure RGB format; expand is a no-copy view, and the arithmetic
            # below materializes a fresh tensor anyway.
            data = data.expand(-1, 3, -1, -1)
        elif channels != 3:
            raise ValueError(f"Expected 1 or 3 channels, got {channels}")

        # Allocate the statistics directly on the input's device (no host copy).
        mean = torch.tensor(self._IMAGENET_MEAN, device=data.device).view(1, 3, 1, 1)
        std = torch.tensor(self._IMAGENET_STD, device=data.device).view(1, 3, 1, 1)
        return (data - mean) / std
