In [None]:

# multimodal_stimulus_fmri_predict/core/base_classifier.py
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, Tuple, List
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import logging

class BaseClassifier(ABC):
    """Abstract base class for all PyTorch classifiers.

    Subclasses supply the network via :meth:`build_model` and any per-batch
    input transformation via :meth:`preprocess_data`. This base class provides
    a generic Adam + cross-entropy training loop, evaluation, and probability
    prediction.

    Config keys read here (optional):
        learning_rate: Adam learning rate used by :meth:`train` (default 1e-4).
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        # Built lazily by train(); may also be assigned directly (e.g. a
        # loaded checkpoint) before calling evaluate()/predict().
        self.model: Optional[nn.Module] = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger = logging.getLogger(self.__class__.__name__)

    @abstractmethod
    def build_model(self) -> nn.Module:
        """Build and return the model architecture."""

    @abstractmethod
    def preprocess_data(self, data: Any) -> Any:
        """Preprocess a batch of input data for the specific classifier.

        Called on every batch after it has been moved to ``self.device``,
        during training, evaluation and prediction alike.
        """

    def _ensure_model(self, build_if_missing: bool = True) -> nn.Module:
        """Return ``self.model`` placed on ``self.device``, building it if allowed.

        Raises:
            RuntimeError: if no model exists and ``build_if_missing`` is False.
        """
        if self.model is None:
            if not build_if_missing:
                raise RuntimeError(
                    f'{type(self).__name__} has no model; call train() first '
                    'or assign self.model')
            self.model = self.build_model()
        # Also covers externally assigned models that were never moved to the
        # target device; a no-op when the parameters are already there.
        self.model.to(self.device)
        return self.model

    def _train_one_epoch(self, model: nn.Module, loader: DataLoader,
                         optimizer: torch.optim.Optimizer,
                         criterion: nn.Module) -> float:
        """Run one optimisation pass over ``loader``; return the mean batch loss.

        Returns ``nan`` if the loader yields no batches.
        """
        model.train()
        total_loss = 0.0
        num_batches = 0
        for data, targets in loader:
            data, targets = data.to(self.device), targets.to(self.device)
            data = self.preprocess_data(data)

            optimizer.zero_grad()
            loss = criterion(model(data), targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
        # Count batches instead of len(loader): works for loaders without
        # __len__ and avoids ZeroDivisionError on an empty loader.
        return total_loss / num_batches if num_batches else float('nan')

    def train(self, train_loader: DataLoader, val_loader: DataLoader,
              epochs: int = 10) -> Dict[str, List[float]]:
        """Train with Adam and cross-entropy, validating after every epoch.

        Builds the model on first use and moves it to ``self.device``.

        Args:
            train_loader: Yields ``(data, targets)`` training batches.
            val_loader: Yields ``(data, targets)`` validation batches.
            epochs: Number of passes over ``train_loader``.

        Returns:
            Dict with per-epoch lists ``'train_loss'`` (mean of per-batch
            losses), ``'val_loss'`` and ``'val_acc'`` (see :meth:`evaluate`).
            An entry is ``nan`` when the corresponding loader was empty.
        """
        model = self._ensure_model()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=self.config.get('learning_rate', 1e-4))
        criterion = nn.CrossEntropyLoss()

        history: Dict[str, List[float]] = {'train_loss': [], 'val_loss': [], 'val_acc': []}

        for epoch in range(epochs):
            train_loss = self._train_one_epoch(model, train_loader, optimizer, criterion)
            val_loss, val_acc = self.evaluate(val_loader)

            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

            self.logger.info('Epoch %d/%d: Train Loss: %.4f, Val Loss: %.4f, Val Acc: %.4f',
                             epoch + 1, epochs, train_loss, val_loss, val_acc)

        return history

    def evaluate(self, data_loader: DataLoader) -> Tuple[float, float]:
        """Compute mean cross-entropy loss and accuracy over ``data_loader``.

        Targets may be class indices or class probabilities with the same
        shape as the model output (both accepted by ``nn.CrossEntropyLoss``).
        For the latter, accuracy compares against the targets' argmax.

        Returns:
            ``(loss, accuracy)``. ``loss`` is the mean of per-batch losses and
            ``accuracy`` is the fraction of correctly predicted targets. Both
            are ``nan`` if the loader yields no batches.

        Raises:
            RuntimeError: if no model has been built or assigned yet.
        """
        model = self._ensure_model(build_if_missing=False)
        model.eval()
        criterion = nn.CrossEntropyLoss()  # hoisted: one instance for all batches
        total_loss = 0.0
        num_batches = 0
        num_correct = 0
        num_samples = 0

        with torch.no_grad():
            for data, targets in data_loader:
                data, targets = data.to(self.device), targets.to(self.device)
                data = self.preprocess_data(data)

                outputs = model(data)
                total_loss += criterion(outputs, targets).item()
                num_batches += 1

                preds = torch.argmax(outputs, dim=1)
                # Probability targets share the output's shape; reduce them to
                # class indices so they are comparable with preds.
                labels = targets.argmax(dim=1) if targets.shape == outputs.shape else targets
                num_correct += (preds == labels).sum().item()
                num_samples += labels.numel()

        if num_batches == 0:
            self.logger.warning('evaluate() received an empty data loader')
            return float('nan'), float('nan')
        accuracy = num_correct / num_samples if num_samples else float('nan')
        return total_loss / num_batches, accuracy

    def predict(self, data: torch.Tensor) -> np.ndarray:
        """Return softmax class probabilities (over dim 1) for ``data``.

        ``data`` is moved to ``self.device`` and passed through
        :meth:`preprocess_data` before the forward pass.

        Raises:
            RuntimeError: if no model has been built or assigned yet.
        """
        model = self._ensure_model(build_if_missing=False)
        model.eval()
        with torch.no_grad():
            data = self.preprocess_data(data.to(self.device))
            probs = torch.softmax(model(data), dim=1)
        return probs.cpu().numpy()