# In [None]:  (Jupyter cell marker kept as a comment so the module parses as Python)
# multimodal_stimulus_fmri_predict/classifiers/vision_transformer.py
import torch
import torch.nn as nn
from transformers import ViTModel, ViTConfig
from ..core.base_classifier import BaseClassifier

class VisionTransformerClassifier(BaseClassifier):
    """Vision Transformer classifier for image-based fMRI prediction.

    Settings are read from ``self.config`` with ``.get`` (defaults in
    parentheses):

    * ``pretrained`` (True): load a published checkpoint instead of building
      a randomly initialised backbone.
    * ``pretrained_model_name`` ('google/vit-base-patch16-224'): checkpoint
      passed to ``ViTModel.from_pretrained`` when ``pretrained`` is truthy.
    * ``image_size`` (224), ``patch_size`` (16), ``hidden_size`` (768),
      ``num_layers`` (12), ``num_heads`` (12): backbone architecture. These
      are only used when ``pretrained`` is falsy; a loaded checkpoint brings
      its own architecture.
    * ``num_classes`` (2): width of the final layer of the head.
    * ``classifier_hidden_size`` (512), ``dropout`` (0.3): hidden width and
      dropout probability of the classification head.
    """

    # Checkpoint used when ``pretrained`` is enabled and no name is configured.
    _DEFAULT_CHECKPOINT = 'google/vit-base-patch16-224'

    def build_model(self) -> nn.Module:
        """Build a ViT backbone with a custom classification head.

        Returns:
            A ``VisionTransformerWrapper`` combining the backbone with a
            ``Linear -> ReLU -> Dropout -> Linear`` head.
        """
        num_classes = self.config.get('num_classes', 2)

        if self.config.get('pretrained', True):
            # NOTE(review): the default checkpoint appears to be published as
            # an image-classification model, so the ViTModel pooler layer may
            # be newly initialised on load -- confirm against the load warning.
            backbone = ViTModel.from_pretrained(
                self.config.get('pretrained_model_name', self._DEFAULT_CHECKPOINT)
            )
        else:
            # The architecture config is only needed for a from-scratch
            # backbone, so it is built on this branch only.
            vit_config = ViTConfig(
                image_size=self.config.get('image_size', 224),
                patch_size=self.config.get('patch_size', 16),
                num_labels=num_classes,
                hidden_size=self.config.get('hidden_size', 768),
                num_hidden_layers=self.config.get('num_layers', 12),
                num_attention_heads=self.config.get('num_heads', 12),
            )
            backbone = ViTModel(vit_config)

        head_width = self.config.get('classifier_hidden_size', 512)
        # The head input width comes from the backbone actually in use, so it
        # matches both pretrained checkpoints and custom ``hidden_size`` values.
        head = nn.Sequential(
            nn.Linear(backbone.config.hidden_size, head_width),
            nn.ReLU(),
            nn.Dropout(self.config.get('dropout', 0.3)),
            nn.Linear(head_width, num_classes),
        )

        return VisionTransformerWrapper(backbone, head)

    def preprocess_data(self, data: torch.Tensor) -> torch.Tensor:
        """Preprocess images for ViT.

        Steps, in order:

        1. A 2-D tensor is treated as one single-channel image and a 3-D
           tensor as one ``(channels, height, width)`` image; both gain the
           missing leading dimensions.
        2. Non-floating-point tensors are cast to the default float dtype so
           the result can be fed to a float model even when no scaling occurs.
        3. If any value exceeds 1.0 the tensor is divided by 255.
        4. A single channel at dimension 1 is repeated three times.

        Inputs whose dimension 1 is neither 1 nor 3 are returned with their
        channel count unchanged.

        Args:
            data: Image tensor; ``(batch, channels, height, width)`` or one of
                the lower-rank forms described above.

        Returns:
            The preprocessed tensor.
        """
        if data.dim() == 2:
            data = data.unsqueeze(0).unsqueeze(0)
        elif data.dim() == 3:
            data = data.unsqueeze(0)

        if not data.is_floating_point():
            # Same dtype that dividing an integer tensor by 255.0 produces.
            data = data.to(torch.get_default_dtype())

        # ``max()`` raises on an empty tensor, so skip the range check then.
        # Heuristic: any value above 1.0 is taken to mean a 0-255 range.
        # Scaling happens before the channel repeat to avoid tripled work.
        if data.numel() > 0 and data.max() > 1.0:
            data = data / 255.0

        if data.shape[1] == 1:
            data = data.repeat(1, 3, 1, 1)

        return data


class VisionTransformerWrapper(nn.Module):
    """Wrapper that feeds a ViT backbone's pooled output to a custom head.

    The backbone is typically a ``transformers.ViTModel``; any module whose
    output exposes ``pooler_output`` and/or ``last_hidden_state`` works.
    """

    def __init__(self, vit_model: nn.Module, classifier: nn.Module):
        """Store the backbone and the head as submodules.

        Args:
            vit_model: Backbone called on the input batch.
            classifier: Head applied to the pooled backbone features.
        """
        super().__init__()
        self.vit = vit_model
        self.classifier = classifier

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and classify its pooled representation.

        Args:
            x: Input batch passed unchanged to the backbone.

        Returns:
            The output of ``classifier`` applied to the pooled features.
        """
        outputs = self.vit(x)
        pooled_output = getattr(outputs, 'pooler_output', None)
        if pooled_output is None:
            # A backbone built without a pooling layer yields no pooled output;
            # use the first token of the last hidden state instead (the [CLS]
            # position in Hugging Face ViT outputs).
            pooled_output = outputs.last_hidden_state[:, 0]
        return self.classifier(pooled_output)