# Tutorial 4: Building Custom Models and Decoders

**Level**: Intermediate  
**Time**: 30-40 minutes  
**Prerequisites**: Tutorial 1, Tutorial 2

## Overview

In this tutorial, you'll learn how to:

1. **Extend BaseModel** - Create custom model classes
2. **Implement Training Logic** - Custom training loops and optimization
3. **Build Neural Decoders** - From scratch using PyTorch
4. **Integrate with Pipelines** - Use custom models in neurOS pipelines
5. **Save & Load Models** - Model persistence and versioning

## Key Concepts

- **BaseModel Interface**: neurOS's model abstraction
- **Custom Architectures**: Building domain-specific models
- **Pipeline Integration**: Seamless model swapping
- **Hyperparameters**: Constructor arguments such as layer sizes, dropout, and learning rate

Let's get started!

---

## Section 1: Understanding the BaseModel Interface

All neurOS models extend `BaseModel`, which provides a consistent interface for training, prediction, and serialization.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Any, Optional

from neuros.models import BaseModel
from neuros.drivers import MockDriver
from neuros.pipeline import Pipeline

# Display the BaseModel interface
print("BaseModel Methods:")
for method in dir(BaseModel):
    if not method.startswith('_'):
        print(f"  - {method}")

### Key Methods to Implement

When creating a custom model, you must implement:

1. **`train(X, y, **kwargs)`** - Training logic
2. **`predict(X, **kwargs)`** - Inference logic
3. **`save(path)`** - Serialization
4. **`load(path)`** - Deserialization (class method)

Optional methods:
- **`predict_proba(X)`** - Probability estimates
- **`score(X, y)`** - Evaluation metric
- **`get_params()`** - Hyperparameters
- **`set_params(**params)`** - Update hyperparameters
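
To make the contract concrete, here is a minimal sketch of a class that satisfies it. `SketchBaseModel` is a hypothetical stand-in written for this example, not neurOS's actual `BaseModel` (whose method set may differ), and `MajorityClassDecoder` is likewise invented for illustration.

```python
import os
import pickle
import tempfile
from abc import ABC, abstractmethod

import numpy as np


class SketchBaseModel(ABC):
    """Hypothetical stand-in for the interface described above (not neurOS code)."""

    @abstractmethod
    def train(self, X, y, **kwargs): ...

    @abstractmethod
    def predict(self, X, **kwargs): ...

    @abstractmethod
    def save(self, path): ...

    @classmethod
    @abstractmethod
    def load(cls, path): ...

    def score(self, X, y):
        """Optional method: plain accuracy."""
        return float(np.mean(self.predict(X) == np.asarray(y)))


class MajorityClassDecoder(SketchBaseModel):
    """Baseline that always predicts the most frequent training label."""

    def __init__(self):
        self.majority_label = None

    def train(self, X, y, **kwargs):
        labels, counts = np.unique(np.asarray(y), return_counts=True)
        self.majority_label = labels[np.argmax(counts)]
        return self

    def predict(self, X, **kwargs):
        if self.majority_label is None:
            raise ValueError("Model must be trained before prediction")
        return np.full(len(X), self.majority_label)

    def save(self, path):
        with open(path, "wb") as f:
            pickle.dump({"majority_label": self.majority_label}, f)

    @classmethod
    def load(cls, path):
        with open(path, "rb") as f:
            state = pickle.load(f)
        model = cls()
        model.majority_label = state["majority_label"]
        return model


X_demo = np.zeros((5, 3))
y_demo = np.array([1, 1, 0, 1, 0])
baseline = MajorityClassDecoder().train(X_demo, y_demo)
baseline_accuracy = baseline.score(X_demo, y_demo)

# Round-trip through save/load to confirm the state survives
with tempfile.TemporaryDirectory() as tmp_dir:
    baseline_path = os.path.join(tmp_dir, "baseline.pkl")
    baseline.save(baseline_path)
    restored = MajorityClassDecoder.load(baseline_path)
```

A majority-class baseline like this is also a useful sanity check: any decoder worth keeping should beat it.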

---

## Section 2: Example 1 - Simple Threshold Decoder

Let's start with a simple threshold-based decoder for motor imagery classification.

In [None]:
class ThresholdDecoder(BaseModel):
    """
    Simple threshold-based decoder.
    
    Predicts class 1 when the mean absolute amplitude across the
    specified channels exceeds a learned threshold, so class 1 is
    assumed to be the higher-amplitude class.
    """
    
    def __init__(self, channels: Optional[list] = None):
        super().__init__()
        self.channels = channels  # Which channels to use
        self.threshold = None  # Learned threshold
        self.is_trained = False
    
    def train(self, X: np.ndarray, y: np.ndarray, **kwargs) -> 'ThresholdDecoder':
        """
        Learn optimal threshold from training data.
        
        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Training features
        y : ndarray, shape (n_samples,)
            Training labels (0 or 1)
        """
        # Use all features if channels not specified
        if self.channels is None:
            self.channels = list(range(X.shape[1]))
        
        # Extract relevant features
        X_subset = X[:, self.channels]
        
        # Compute mean amplitude for each sample
        amplitudes = np.mean(np.abs(X_subset), axis=1)
        
        # Find threshold that best separates classes
        # Simple approach: midpoint between class means
        class_0_mean = np.mean(amplitudes[y == 0])
        class_1_mean = np.mean(amplitudes[y == 1])
        self.threshold = (class_0_mean + class_1_mean) / 2
        
        self.is_trained = True
        
        print(f"Threshold learned: {self.threshold:.4f}")
        print(f"  Class 0 mean: {class_0_mean:.4f}")
        print(f"  Class 1 mean: {class_1_mean:.4f}")
        
        return self
    
    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """
        Predict class labels.
        """
        if not self.is_trained:
            raise ValueError("Model must be trained before prediction")
        
        # Extract relevant features
        X_subset = X[:, self.channels]
        
        # Compute amplitudes
        amplitudes = np.mean(np.abs(X_subset), axis=1)
        
        # Threshold classification
        predictions = (amplitudes > self.threshold).astype(int)
        
        return predictions
    
    def save(self, path: str) -> None:
        """Save model to disk."""
        import pickle
        model_data = {
            'channels': self.channels,
            'threshold': self.threshold,
            'is_trained': self.is_trained
        }
        with open(path, 'wb') as f:
            pickle.dump(model_data, f)
    
    @classmethod
    def load(cls, path: str) -> 'ThresholdDecoder':
        """Load model from disk."""
        import pickle
        with open(path, 'rb') as f:
            model_data = pickle.load(f)
        
        model = cls(channels=model_data['channels'])
        model.threshold = model_data['threshold']
        model.is_trained = model_data['is_trained']
        return model

print("✓ ThresholdDecoder class defined")

### Test the Threshold Decoder

In [None]:
# Generate synthetic data
np.random.seed(42)
n_samples = 200
n_features = 10

# Class 0: Low amplitude
X_class0 = np.random.randn(n_samples // 2, n_features) * 0.5
y_class0 = np.zeros(n_samples // 2, dtype=int)  # integer class labels

# Class 1: High amplitude
X_class1 = np.random.randn(n_samples // 2, n_features) * 1.5 + 2.0
y_class1 = np.ones(n_samples // 2, dtype=int)

# Combine
X = np.vstack([X_class0, X_class1])
y = np.hstack([y_class0, y_class1])

# Shuffle
indices = np.random.permutation(n_samples)
X = X[indices]
y = y[indices]

# Split train/test
split = int(0.7 * n_samples)
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]

# Train model
model = ThresholdDecoder()
model.train(X_train, y_train)

# Predict
y_pred = model.predict(X_test)

# Evaluate
accuracy = np.mean(y_pred == y_test)
print(f"\n✓ Test Accuracy: {accuracy:.2%}")
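
The midpoint rule only works as written when class 1 has the larger mean amplitude, because `predict` labels everything above the threshold as class 1. The sketch below shows one way to handle the opposite case by also learning a polarity. The helper names `fit_midpoint_threshold` and `apply_threshold` are hypothetical and not part of neurOS.

```python
import numpy as np


def fit_midpoint_threshold(amplitudes, labels):
    """Return (threshold, polarity) for a two-class midpoint rule.

    polarity is +1 when class 1 has the larger mean amplitude, else -1.
    """
    amplitudes = np.asarray(amplitudes, dtype=float)
    labels = np.asarray(labels)
    mean0 = amplitudes[labels == 0].mean()
    mean1 = amplitudes[labels == 1].mean()
    threshold = (mean0 + mean1) / 2.0
    polarity = 1 if mean1 >= mean0 else -1
    return threshold, polarity


def apply_threshold(amplitudes, threshold, polarity):
    """Predict class 1 on the side of the threshold where class 1 was observed."""
    amplitudes = np.asarray(amplitudes, dtype=float)
    return (polarity * (amplitudes - threshold) > 0).astype(int)


# Toy data where class 1 has the LOWER amplitude
amps = np.array([2.0, 2.2, 1.8, 0.4, 0.6, 0.5])
labs = np.array([0, 0, 0, 1, 1, 1])

threshold, polarity = fit_midpoint_threshold(amps, labs)
preds = apply_threshold(amps, threshold, polarity)
print(threshold, polarity, preds)
```

Here the class means are 2.0 and 0.5, so the threshold is 1.25 and the polarity is -1. Without the polarity term, every sample in this toy set would be misclassified.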

---

## Section 3: Example 2 - Custom PyTorch Decoder

Now let's build a more sophisticated decoder using PyTorch.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

class CustomNeuralDecoder(BaseModel):
    """
    Custom neural network decoder with flexible architecture.
    """
    
    def __init__(
        self,
        input_dim: int,
        hidden_dims: Optional[list] = None,
        n_classes: int = 2,
        dropout: float = 0.3,
        learning_rate: float = 0.001
    ):
        super().__init__()
        self.input_dim = input_dim
        # Avoid a mutable default argument; copy so the caller's list is not shared
        self.hidden_dims = list(hidden_dims) if hidden_dims is not None else [64, 32]
        self.n_classes = n_classes
        self.dropout = dropout
        self.learning_rate = learning_rate
        
        # Build network
        self.network = self._build_network()
        self.optimizer = None
        self.is_trained = False
    
    def _build_network(self) -> nn.Module:
        """Build the neural network architecture."""
        layers = []
        
        # Input layer
        prev_dim = self.input_dim
        
        # Hidden layers
        for hidden_dim in self.hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.Dropout(self.dropout))
            prev_dim = hidden_dim
        
        # Output layer
        layers.append(nn.Linear(prev_dim, self.n_classes))
        
        return nn.Sequential(*layers)
    
    def train(
        self,
        X: np.ndarray,
        y: np.ndarray,
        epochs: int = 50,
        batch_size: int = 32,
        validation_split: float = 0.2,
        verbose: bool = True,
        **kwargs
    ) -> 'CustomNeuralDecoder':
        """
        Train the neural decoder.
        """
        # Convert to tensors with explicit dtypes (float32 inputs, int64 labels)
        X_tensor = torch.as_tensor(X, dtype=torch.float32)
        y_tensor = torch.as_tensor(y, dtype=torch.long)
        
        # Create validation split
        n_val = int(len(X) * validation_split)
        indices = torch.randperm(len(X))
        
        val_indices = indices[:n_val]
        train_indices = indices[n_val:]
        
        X_train, y_train = X_tensor[train_indices], y_tensor[train_indices]
        X_val, y_val = X_tensor[val_indices], y_tensor[val_indices]
        
        # Create data loaders
        train_dataset = TensorDataset(X_train, y_train)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        
        # Setup optimizer
        self.optimizer = optim.Adam(self.network.parameters(), lr=self.learning_rate)
        criterion = nn.CrossEntropyLoss()
        
        # Training history
        history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
        
        # Training loop
        for epoch in range(epochs):
            self.network.train()
            train_loss = 0.0
            
            for batch_X, batch_y in train_loader:
                self.optimizer.zero_grad()
                outputs = self.network(batch_X)
                loss = criterion(outputs, batch_y)
                loss.backward()
                self.optimizer.step()
                train_loss += loss.item()
            
            train_loss /= len(train_loader)
            
            # Validation
            self.network.eval()
            with torch.no_grad():
                val_outputs = self.network(X_val)
                val_loss = criterion(val_outputs, y_val).item()
                val_pred = torch.argmax(val_outputs, dim=1)
                val_acc = (val_pred == y_val).float().mean().item()
            
            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            
            if verbose and (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs} - "
                      f"Train Loss: {train_loss:.4f}, "
                      f"Val Loss: {val_loss:.4f}, "
                      f"Val Acc: {val_acc:.4f}")
        
        self.is_trained = True
        self.training_history = history
        
        return self
    
    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """Predict class labels."""
        if not self.is_trained:
            raise ValueError("Model must be trained before prediction")
        
        self.network.eval()
        with torch.no_grad():
            X_tensor = torch.FloatTensor(X)
            outputs = self.network(X_tensor)
            predictions = torch.argmax(outputs, dim=1).numpy()
        
        return predictions
    
    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Predict class probabilities."""
        if not self.is_trained:
            raise ValueError("Model must be trained before prediction")
        
        self.network.eval()
        with torch.no_grad():
            X_tensor = torch.FloatTensor(X)
            outputs = self.network(X_tensor)
            probas = torch.softmax(outputs, dim=1).numpy()
        
        return probas
    
    def save(self, path: str) -> None:
        """Save model to disk."""
        torch.save({
            'network_state': self.network.state_dict(),
            'optimizer_state': self.optimizer.state_dict() if self.optimizer else None,
            'config': {
                'input_dim': self.input_dim,
                'hidden_dims': self.hidden_dims,
                'n_classes': self.n_classes,
                'dropout': self.dropout,
                'learning_rate': self.learning_rate
            },
            'is_trained': self.is_trained
        }, path)
    
    @classmethod
    def load(cls, path: str) -> 'CustomNeuralDecoder':
        """Load model from disk."""
        checkpoint = torch.load(path)
        config = checkpoint['config']
        
        model = cls(**config)
        model.network.load_state_dict(checkpoint['network_state'])
        
        if checkpoint['optimizer_state']:
            model.optimizer = optim.Adam(model.network.parameters())
            model.optimizer.load_state_dict(checkpoint['optimizer_state'])
        
        model.is_trained = checkpoint['is_trained']
        return model

print("✓ CustomNeuralDecoder class defined")
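
Before training, it can help to know how many trainable parameters a given `hidden_dims` choice produces. The sketch below counts them in plain Python, assuming the layer stack built by `_build_network` (Linear, ReLU, BatchNorm1d, Dropout per hidden layer, then a final Linear). The function name `count_mlp_parameters` is hypothetical.

```python
def count_mlp_parameters(input_dim, hidden_dims, n_classes):
    """Count trainable parameters of the Linear/ReLU/BatchNorm1d/Dropout stack."""
    total = 0
    prev_dim = input_dim
    for hidden_dim in hidden_dims:
        total += prev_dim * hidden_dim + hidden_dim  # Linear weights + biases
        total += 2 * hidden_dim                      # BatchNorm1d scale + shift
        prev_dim = hidden_dim                        # ReLU and Dropout add none
    total += prev_dim * n_classes + n_classes        # output Linear
    return total


n_params = count_mlp_parameters(input_dim=10, hidden_dims=[32, 16], n_classes=2)
print(n_params)  # 1010
```

With the 140 training samples used in this tutorial, a network of this size has several times more parameters than examples, which is one reason the class includes dropout and a validation split.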

### Train and Evaluate the Neural Decoder

In [None]:
# Create model
neural_model = CustomNeuralDecoder(
    input_dim=n_features,
    hidden_dims=[32, 16],
    n_classes=2,
    dropout=0.2,
    learning_rate=0.01
)

# Train
print("Training neural decoder...\n")
neural_model.train(X_train, y_train, epochs=50, batch_size=16, verbose=True)

# Predict
y_pred_neural = neural_model.predict(X_test)
y_proba_neural = neural_model.predict_proba(X_test)

# Evaluate
accuracy_neural = np.mean(y_pred_neural == y_test)
print(f"\n✓ Neural Decoder Test Accuracy: {accuracy_neural:.2%}")
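
`predict_proba` turns the network's raw outputs into probabilities with a softmax, and those probabilities can be used to withhold low-confidence decisions. The NumPy sketch below illustrates both steps on hand-written logits. The `min_confidence` cutoff and the `-1` rejection label are assumptions made for this example, not neurOS conventions.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax with the usual max-subtraction for numerical stability."""
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def predict_with_rejection(probas, min_confidence=0.8, reject_label=-1):
    """Return the argmax label, or reject_label when the top probability is too low."""
    probas = np.asarray(probas)
    labels = probas.argmax(axis=1)
    confident = probas.max(axis=1) >= min_confidence
    return np.where(confident, labels, reject_label)


demo_logits = np.array([[4.0, 0.0], [0.1, 0.0], [0.0, 3.0]])
demo_probas = softmax(demo_logits)
demo_labels = predict_with_rejection(demo_probas, min_confidence=0.8)
print(demo_labels)  # [ 0 -1  1]
```

The second row is nearly a coin flip (about 0.52 vs 0.48), so it is rejected rather than assigned to a class.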

### Visualize Training History

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))

# Plot loss
ax1.plot(neural_model.training_history['train_loss'], label='Train Loss', linewidth=2)
ax1.plot(neural_model.training_history['val_loss'], label='Val Loss', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot accuracy
ax2.plot(neural_model.training_history['val_acc'], label='Val Accuracy', linewidth=2, color='green')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Validation Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
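
The validation-loss curve also tells you where training could have stopped. As a rough sketch, the function below scans a list like `training_history['val_loss']` and reports the epoch at which a patience-based early-stopping rule would trigger. The name `early_stopping_epoch` and its `patience` and `min_delta` parameters are hypothetical, and the losses are made-up numbers.

```python
def early_stopping_epoch(val_losses, patience=5, min_delta=0.0):
    """Return (stop_epoch, best_epoch), both 0-based.

    stop_epoch is the epoch at which training would halt, or the last
    epoch if the rule never triggers.
    """
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss = loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return len(val_losses) - 1, best_epoch


demo_val_loss = [0.90, 0.70, 0.55, 0.50, 0.52, 0.51, 0.53, 0.60, 0.65]
stop_epoch, best_epoch = early_stopping_epoch(demo_val_loss, patience=3)
print(stop_epoch, best_epoch)  # 6 3
```

To use the same rule during training, the check would move inside the epoch loop and the network weights from `best_epoch` would need to be kept, for example by copying `state_dict()` whenever the loss improves.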

---

## Section 4: Pipeline Integration

Let's integrate our custom models into neurOS pipelines.

In [None]:
# Create pipeline with custom model
driver = MockDriver(n_channels=n_features, sampling_rate=250)

# Option 1: Use threshold decoder
pipeline_threshold = Pipeline(driver=driver, model=model)

# Option 2: Use neural decoder
pipeline_neural = Pipeline(driver=driver, model=neural_model)

print("✓ Pipelines created with custom models")
print(f"  - Threshold Decoder Pipeline: {pipeline_threshold}")
print(f"  - Neural Decoder Pipeline: {pipeline_neural}")

### Compare Model Performance

In [None]:
# Evaluate both models
models = {
    'Threshold Decoder': model,
    'Neural Decoder': neural_model
}

results = {}
for name, mdl in models.items():
    y_pred = mdl.predict(X_test)
    accuracy = np.mean(y_pred == y_test)
    results[name] = accuracy
    print(f"{name}: {accuracy:.2%} accuracy")

# Visualization
plt.figure(figsize=(8, 5))
bars = plt.bar(list(results.keys()), list(results.values()), color=['steelblue', 'coral'])
plt.ylabel('Accuracy')
plt.title('Model Comparison')
plt.ylim([0, 1])
plt.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height,
             f'{height:.2%}',
             ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()
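
Accuracy alone can hide a decoder that does well on one class and poorly on the other. A confusion matrix and per-class recall make that visible. The sketch below uses only NumPy, and the function names and toy labels are illustrative.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, n_classes=2):
    """Rows index the true class, columns the predicted class."""
    y_true = np.asarray(y_true).astype(int)
    y_pred = np.asarray(y_pred).astype(int)
    matrix = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(matrix, (y_true, y_pred), 1)  # unbuffered add handles repeated index pairs
    return matrix


def per_class_recall(matrix):
    """Fraction of each true class that was predicted correctly."""
    support = matrix.sum(axis=1)
    return np.divide(
        np.diag(matrix), support,
        out=np.zeros(len(support)), where=support > 0,
    )


demo_true = np.array([0, 0, 0, 0, 1, 1])
demo_pred = np.array([0, 0, 0, 1, 1, 0])

cm = confusion_matrix(demo_true, demo_pred)
recall = per_class_recall(cm)
print(cm)
print(recall)  # [0.75 0.5 ]
```

In this toy case overall accuracy is 4/6, but recall for class 1 is only 0.5. The same functions accept `y_test` and any decoder's predictions.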

---

## Section 5: Model Persistence

Save and load models for reuse.

In [None]:
import tempfile
import os

# Create temporary directory for models
model_dir = tempfile.mkdtemp()

# Save models
threshold_path = os.path.join(model_dir, 'threshold_decoder.pkl')
neural_path = os.path.join(model_dir, 'neural_decoder.pth')

model.save(threshold_path)
neural_model.save(neural_path)

print(f"✓ Models saved to {model_dir}")
print(f"  - Threshold: {threshold_path}")
print(f"  - Neural: {neural_path}")

# Load models
loaded_threshold = ThresholdDecoder.load(threshold_path)
loaded_neural = CustomNeuralDecoder.load(neural_path)

# Verify loaded models work
y_pred_loaded_threshold = loaded_threshold.predict(X_test)
y_pred_loaded_neural = loaded_neural.predict(X_test)

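# Check they produce the same results as the in-memory models.
# Recompute the threshold predictions: y_pred was overwritten in the comparison loop.
assert np.array_equal(model.predict(X_test), y_pred_loaded_threshold), "Threshold model changed after loading!"
assert np.array_equal(y_pred_neural, y_pred_loaded_neural), "Neural model changed after loading!"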

print("\n✓ Models loaded successfully and produce identical predictions")
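
For versioning, one lightweight option is to store a format version next to the model state and check it on load. The threshold decoder's state is just a list of channel indices and a float, so it also fits in JSON, which avoids pickle's ability to execute arbitrary code when loading an untrusted file. The following is a sketch under those assumptions: `FORMAT_VERSION`, the field names, and both helper functions are hypothetical.

```python
import json
import os
import tempfile

FORMAT_VERSION = 1  # hypothetical schema version for this sketch


def save_threshold_state(path, channels, threshold):
    """Write the decoder state as JSON with an explicit format version."""
    state = {
        "format_version": FORMAT_VERSION,
        "model_type": "ThresholdDecoder",
        "channels": [int(c) for c in channels],
        "threshold": float(threshold),
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(state, f, indent=2)


def load_threshold_state(path):
    """Read the state back, refusing files written with another format version."""
    with open(path, "r", encoding="utf-8") as f:
        state = json.load(f)
    version = state.get("format_version")
    if version != FORMAT_VERSION:
        raise ValueError(f"Unsupported format_version: {version!r}")
    return state["channels"], state["threshold"]


with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "threshold_decoder.json")
    save_threshold_state(demo_path, channels=[0, 1, 2], threshold=1.25)
    restored_channels, restored_threshold = load_threshold_state(demo_path)

print(restored_channels, restored_threshold)
```

When the saved fields change, bumping `FORMAT_VERSION` makes old files fail loudly instead of loading with silently missing values.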

---

## Section 6: Advanced Example - Time-Series Decoder

Build a decoder that handles sequential/temporal data.

In [None]:
class TemporalDecoder(BaseModel):
    """
    LSTM-based decoder for temporal sequences.
    """
    
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 64,
        n_classes: int = 2,
        n_layers: int = 2,
        dropout: float = 0.3
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.dropout = dropout
        
        # Build LSTM network
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0
        )
        self.fc = nn.Linear(hidden_dim, n_classes)
        self.is_trained = False
    
    def train(
        self,
        X: np.ndarray,
        y: np.ndarray,
        epochs: int = 30,
        batch_size: int = 16,
        learning_rate: float = 0.001,
        **kwargs
    ) -> 'TemporalDecoder':
        """
        Train the temporal decoder.
        
        Parameters
        ----------
        X : ndarray, shape (n_samples, sequence_length, n_features)
            Sequential input data
        y : ndarray, shape (n_samples,)
            Labels
        """
        # Ensure 3D input
        if X.ndim == 2:
            X = X[:, np.newaxis, :]  # Add sequence dimension
        
        X_tensor = torch.FloatTensor(X)
        y_tensor = torch.LongTensor(y)
        
        dataset = TensorDataset(X_tensor, y_tensor)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        
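        optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        criterion = nn.CrossEntropyLoss()
        
        # Enable training mode so LSTM dropout is active, even after an earlier predict()
        self.lstm.train()
        self.fc.train()
        
        for epoch in range(epochs):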
            total_loss = 0
            for batch_X, batch_y in loader:
                optimizer.zero_grad()
                
                # Forward pass
                lstm_out, _ = self.lstm(batch_X)
                # Use last time step
                out = self.fc(lstm_out[:, -1, :])
                
                loss = criterion(out, batch_y)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
            
            if (epoch + 1) % 10 == 0:
                avg_loss = total_loss / len(loader)
                print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
        
        self.is_trained = True
        return self
    
    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """Predict on sequential data."""
        if not self.is_trained:
            raise ValueError("Model must be trained first")
        
        if X.ndim == 2:
            X = X[:, np.newaxis, :]
        
        # Switch the submodules to eval mode (the decoder itself defines no eval())
        self.lstm.eval()
        self.fc.eval()
        with torch.no_grad():
            X_tensor = torch.FloatTensor(X)
            lstm_out, _ = self.lstm(X_tensor)
            out = self.fc(lstm_out[:, -1, :])
            predictions = torch.argmax(out, dim=1).numpy()
        
        return predictions
    
    def parameters(self):
        """Return model parameters."""
        return list(self.lstm.parameters()) + list(self.fc.parameters())
    
    def save(self, path: str) -> None:
        """Save model."""
        torch.save({
            'lstm_state': self.lstm.state_dict(),
            'fc_state': self.fc.state_dict(),
            'config': {
                'input_dim': self.input_dim,
                'hidden_dim': self.hidden_dim,
                'n_classes': self.n_classes,
                'n_layers': self.n_layers,
                'dropout': self.dropout
            },
            'is_trained': self.is_trained
        }, path)
    
    @classmethod
    def load(cls, path: str) -> 'TemporalDecoder':
        """Load model."""
        checkpoint = torch.load(path)
        model = cls(**checkpoint['config'])
        model.lstm.load_state_dict(checkpoint['lstm_state'])
        model.fc.load_state_dict(checkpoint['fc_state'])
        # Restore the saved flag; checkpoints without it are assumed to be trained
        model.is_trained = checkpoint.get('is_trained', True)
        return model

print("✓ TemporalDecoder class defined")
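
`TemporalDecoder` expects input of shape `(n_samples, sequence_length, n_features)`. If it receives 2D data it adds a sequence dimension of length 1, so the LSTM only ever sees a single time step. To give it real sequences, a continuous recording can be cut into overlapping windows first. The sketch below does this with NumPy. `make_windows` and its parameters are hypothetical names, and the recording is synthetic.

```python
import numpy as np


def make_windows(recording, window_length, step):
    """Slice a (n_timepoints, n_channels) array into overlapping windows.

    Returns an array of shape (n_windows, window_length, n_channels).
    """
    recording = np.asarray(recording)
    n_timepoints = recording.shape[0]
    if n_timepoints < window_length:
        raise ValueError("Recording is shorter than one window")
    starts = range(0, n_timepoints - window_length + 1, step)
    return np.stack([recording[s:s + window_length] for s in starts])


# 10 timepoints, 2 channels
demo_recording = np.arange(20, dtype=float).reshape(10, 2)
demo_windows = make_windows(demo_recording, window_length=4, step=2)
print(demo_windows.shape)  # (4, 4, 2)
```

Each window needs one label, for example the class that was active at the end of the window. Overlapping windows from the same recording are correlated, so train/test splits are best made before windowing.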

---

## Summary

In this tutorial, you learned:

✅ **BaseModel Interface** - The foundation for all neurOS models  
✅ **Custom Decoders** - Built threshold and neural decoders from scratch  
✅ **PyTorch Integration** - Created sophisticated neural architectures  
✅ **Pipeline Integration** - Seamlessly swapped models in pipelines  
✅ **Model Persistence** - Saved and loaded models  
✅ **Temporal Models** - Handled sequential data with LSTMs  

## Next Steps

- **Tutorial 5**: Benchmarking & Performance Optimization
- **Tutorial 6**: Real-World NWB Data Integration
- **Advanced**: Hyperparameter tuning with Optuna
- **Advanced**: Multi-modal fusion architectures

## Exercises

1. **Add regularization** - Implement L1/L2 regularization in CustomNeuralDecoder
2. **Early stopping** - Add early stopping to prevent overfitting
3. **Ensemble models** - Combine multiple decoders
4. **Custom loss functions** - Implement focal loss or class-weighted loss
5. **Attention mechanism** - Add attention to TemporalDecoder
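
As a starting point for exercise 4, here is a NumPy sketch of the focal loss formula, FL = -(1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. It works on probabilities such as those returned by `predict_proba`. Porting it to PyTorch so it can replace `nn.CrossEntropyLoss` in the training loop is left as the exercise. The function name and the toy values are illustrative.

```python
import numpy as np


def focal_loss(probas, labels, gamma=2.0, eps=1e-12):
    """Mean focal loss over samples; gamma=0 reduces to cross-entropy."""
    probas = np.asarray(probas, dtype=float)
    labels = np.asarray(labels).astype(int)
    p_true = probas[np.arange(len(labels)), labels]  # probability of the true class
    p_true = np.clip(p_true, eps, 1.0)               # avoid log(0)
    return float(np.mean(-((1.0 - p_true) ** gamma) * np.log(p_true)))


demo_probas = np.array([[0.9, 0.1], [0.6, 0.4], [0.2, 0.8]])
demo_labels = np.array([0, 1, 1])

cross_entropy = focal_loss(demo_probas, demo_labels, gamma=0.0)
focal = focal_loss(demo_probas, demo_labels, gamma=2.0)
print(cross_entropy, focal)
```

Raising `gamma` shrinks the contribution of samples the model already classifies confidently, so the misclassified second sample dominates the focal value.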

---

**Questions or feedback?** Open an issue on GitHub or check the docs at https://neuros.readthedocs.io
