# Motor Imagery Classification with NeurOS

This notebook demonstrates how to build a motor imagery BCI system using NeurOS. Motor imagery means imagining a movement (e.g., of the left or right hand) without performing it; this produces distinct, spatially localized changes in EEG rhythms over the motor cortex.

**What you'll learn:**
- Setting up a BCI pipeline for motor imagery
- Training an EEGNet model on synthetic motor imagery data
- Evaluating model performance
- Saving models to the registry
- Real-time classification simulation

In [None]:
# Install neuros if needed
# !pip install -e ..

import numpy as np
import matplotlib.pyplot as plt
from neuros.models import EEGNetModel, ModelRegistry
from neuros.pipeline import Pipeline
from neuros.drivers import MockDriver

print("✓ Imports successful")

## 1. Generate Synthetic Motor Imagery Data

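For this demo, we'll create synthetic EEG data with a deliberately simplified pattern that keeps the three classes easy to separate (in real recordings, hand motor imagery typically *decreases* mu power over the contralateral motor cortex, an effect known as event-related desynchronization):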
- **Class 0 (Rest):** Baseline activity
- **Class 1 (Left Hand):** Enhanced mu rhythm (8-12 Hz) over left motor cortex
- **Class 2 (Right Hand):** Enhanced mu rhythm over right motor cortex
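The classes differ only in mu-band (8-12 Hz) power between the two channel groups. A minimal sketch of measuring that with NumPy's FFT is below; `band_power` is a hypothetical helper written for illustration (not a NeurOS API), and the toy trial mimics the "left hand" pattern described above.

```python
import numpy as np

def band_power(trial, fs, band):
    """Mean FFT power per channel within `band` = (low_hz, high_hz).

    trial: array of shape (n_channels, n_timepoints). Power is unnormalized
    |FFT|^2, which is fine for comparing channels within one trial.
    """
    freqs = np.fft.rfftfreq(trial.shape[-1], 1 / fs)
    power = np.abs(np.fft.rfft(trial, axis=-1)) ** 2
    in_band = (freqs >= band[0]) & (freqs <= band[1])
    return power[:, in_band].mean(axis=-1)

# Toy "left hand" trial: 8 noisy channels, 10 Hz rhythm added to channels 0-3
demo_fs = 250.0
demo_t = np.arange(250) / demo_fs
demo_rng = np.random.default_rng(0)
demo_trial = demo_rng.standard_normal((8, 250)) * 0.5
demo_trial[:4] += 2.0 * np.sin(2 * np.pi * 10 * demo_t)

demo_mu = band_power(demo_trial, demo_fs, band=(8, 12))
print(f"mean mu power, channels 0-3: {demo_mu[:4].mean():.0f}")
print(f"mean mu power, channels 4-7: {demo_mu[4:].mean():.0f}")
```

A classifier only has to learn which channel group carries the extra 8-12 Hz power, which is why this toy problem is much easier than real motor imagery data.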

In [None]:
def generate_motor_imagery_data(n_samples=300, n_channels=8, n_timepoints=250, fs=250.0, seed=None):
    """
    Generate synthetic motor imagery EEG data.

    Parameters
    ----------
    n_samples : int
        Number of trials
    n_channels : int
        Number of EEG channels; the first half stands in for the left motor
        cortex, the second half for the right motor cortex
    n_timepoints : int
        Number of time points per trial (default: 1 second at 250 Hz)
    fs : float
        Sampling frequency in Hz
    seed : int or None
        Random seed, for reproducible data

    Returns
    -------
    X : np.ndarray
        EEG data of shape (n_samples, n_channels, n_timepoints)
    y : np.ndarray
        Labels (0=rest, 1=left hand, 2=right hand)
    """
    rng = np.random.default_rng(seed)
    half = n_channels // 2

    t = np.arange(n_timepoints) / fs
    mu_signal = 2.0 * np.sin(2 * np.pi * 10 * t)  # 10 Hz "mu rhythm"

    # Spectral shaping: scaling amplitudes by 1/sqrt(f + 1) gives a ~1/f power spectrum.
    # Computed once, since it is the same for every trial and channel.
    freqs = np.fft.rfftfreq(n_timepoints, 1 / fs)
    spectral_scale = 1 / np.sqrt(freqs + 1)

    X = np.empty((n_samples, n_channels, n_timepoints))
    y = rng.integers(0, 3, size=n_samples)  # random class for each trial

    for i, label in enumerate(y):
        # Base white noise
        trial = rng.standard_normal((n_channels, n_timepoints)) * 0.5

        if label == 1:    # Left hand imagery: mu rhythm on the "left motor cortex" channels
            trial[:half] += mu_signal
        elif label == 2:  # Right hand imagery: mu rhythm on the "right motor cortex" channels
            trial[half:] += mu_signal

        # Apply the 1/f shaping to the whole trial (noise and mu rhythm alike),
        # which also attenuates the 10 Hz component by a factor of 1/sqrt(11)
        X[i] = np.fft.irfft(np.fft.rfft(trial, axis=-1) * spectral_scale, n=n_timepoints, axis=-1)

    return X, y

# Generate dataset (seeded so the results are reproducible)
print("Generating synthetic motor imagery data...")
X, y = generate_motor_imagery_data(n_samples=300, seed=42)
print(f"✓ Generated {len(X)} trials")
print(f"  Shape: {X.shape} (samples, channels, timepoints)")
print(f"  Trials per class (rest, left, right): {np.bincount(y)}")
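The spectral-shaping step in the generator scales each frequency's amplitude by 1/sqrt(f + 1), so power (amplitude squared) falls off as 1/(f + 1). Because it runs after the mu rhythm is added, it also shrinks the 10 Hz component. The short check below works through that arithmetic for the default settings (250 samples at 250 Hz, i.e. 1 Hz frequency resolution); it reuses only the scaling formula, not the generator itself.

```python
import numpy as np

fs, n = 250.0, 250
freqs = np.fft.rfftfreq(n, 1 / fs)   # 0, 1, ..., 125 Hz (1 Hz resolution)
scale = 1 / np.sqrt(freqs + 1)       # amplitude scaling used in the generator

# Power is amplitude squared, so the shaped power spectrum goes as 1/(f + 1)
assert np.allclose(scale ** 2, 1 / (freqs + 1))

# Gain applied to the 10 Hz bin: 1/sqrt(10 + 1)
k10 = np.argmin(np.abs(freqs - 10.0))
mu_gain = scale[k10]
print(f"10 Hz amplitude gain: {mu_gain:.3f}")
print(f"mu amplitude after shaping: {2.0 * mu_gain:.2f} (was 2.0)")

# Empirical check: shape a pure 10 Hz sine of amplitude 2.0 the same way
t = np.arange(n) / fs
shaped = np.fft.irfft(np.fft.rfft(2.0 * np.sin(2 * np.pi * 10 * t)) * scale, n=n)
print(f"peak amplitude of shaped sine: {np.abs(shaped).max():.2f}")
```

So the effective mu amplitude in the generated trials is about 0.6 rather than 2.0; increasing the 2.0 factor would make the toy problem even easier.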

## 2. Visualize Sample Trials

In [None]:
fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)
class_names = ['Rest', 'Left Hand', 'Right Hand']
n_channels = X.shape[1]
offset = 3  # vertical spacing between stacked channel traces

for cls in range(3):
    # Find the first trial of this class
    idx = np.where(y == cls)[0][0]

    # Plot all channels, stacked with a vertical offset
    for ch in range(n_channels):
        axes[cls].plot(X[idx, ch, :] + ch * offset, alpha=0.7, linewidth=0.8)

    axes[cls].set_title(f"Class {cls}: {class_names[cls]}", fontsize=12, fontweight='bold')
    axes[cls].set_yticks(np.arange(n_channels) * offset)
    axes[cls].set_yticklabels([f"Ch {ch}" for ch in range(n_channels)])
    axes[cls].set_ylim(-offset, n_channels * offset)

axes[2].set_xlabel('Time (samples)')
plt.tight_layout()
plt.show()

print("Note: Left hand trials show a 10 Hz oscillation in channels 0-3")
print("      Right hand trials show a 10 Hz oscillation in channels 4-7")

## 3. Split Data into Train/Test Sets

In [None]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(f"Training set: {len(X_train)} trials")
print(f"Test set: {len(X_test)} trials")
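`stratify=y` makes each split keep the class proportions of the full dataset, which matters most when classes are imbalanced. The sketch below illustrates the idea with NumPy on deliberately imbalanced toy labels; `stratified_split_indices` is a hypothetical helper and does not reproduce scikit-learn's exact shuffling or rounding rules.

```python
import numpy as np

def stratified_split_indices(y, test_size=0.2, seed=0):
    """Simplified stratified split: hold out ~test_size of *each* class."""
    rng = np.random.default_rng(seed)
    test_idx = []
    for cls in np.unique(y):
        cls_idx = rng.permutation(np.flatnonzero(y == cls))  # shuffle this class
        n_test = int(round(test_size * len(cls_idx)))
        test_idx.extend(cls_idx[:n_test])
    test_idx = np.sort(np.array(test_idx, dtype=int))
    train_idx = np.setdiff1d(np.arange(len(y)), test_idx)  # everything else
    return train_idx, test_idx

# Imbalanced toy labels: 150 rest, 100 left, 50 right
y_demo = np.repeat([0, 1, 2], [150, 100, 50])
train_idx, test_idx = stratified_split_indices(y_demo, test_size=0.2)
print("train counts:", np.bincount(y_demo[train_idx]))  # counts 120, 80, 40
print("test counts: ", np.bincount(y_demo[test_idx]))   # counts 30, 20, 10
```

Both splits keep the 3:2:1 ratio of the full label set; an unstratified random split would only do so approximately.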

## 4. Train EEGNet Model

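EEGNet is a compact convolutional neural network designed for EEG classification. A temporal convolution learns frequency-selective filters, a depthwise convolution learns spatial filters for each of them, and a separable convolution combines the results before a small dense classification layer.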
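To see how compact the network is, the sketch below traces tensor shapes (without the batch dimension) through EEGNet's two convolutional blocks. It assumes the EEGNet-8,2 configuration from the original paper (Lawhern et al., 2018): F1 = 8 temporal filters, depth multiplier D = 2, F2 = 16 separable filters, "same" padding on the temporal and separable convolutions, and average pooling by 4 and then 8. NeurOS's `EEGNetModel` may use different values, so treat the numbers as illustrative.

```python
def eegnet_feature_shape(n_channels, n_timepoints, F1=8, D=2, pool1=4, pool2=8):
    """Trace (filters, height, time) shapes through EEGNet's blocks.

    Assumes 'same' padding for the temporal and separable convolutions
    (kernel length then does not change the shape), a 'valid' depthwise
    spatial convolution spanning all channels, and floor division in pooling.
    """
    F2 = F1 * D
    shapes = [("input", (1, n_channels, n_timepoints))]
    # Block 1: temporal conv keeps the time length
    shapes.append(("temporal conv", (F1, n_channels, n_timepoints)))
    # Depthwise spatial conv (kernel n_channels x 1, D per filter) collapses channels
    shapes.append(("depthwise spatial conv", (F1 * D, 1, n_timepoints)))
    t1 = n_timepoints // pool1
    shapes.append((f"avg pool 1x{pool1}", (F1 * D, 1, t1)))
    # Block 2: separable conv keeps the length and outputs F2 feature maps
    shapes.append(("separable conv", (F2, 1, t1)))
    t2 = t1 // pool2
    shapes.append((f"avg pool 1x{pool2}", (F2, 1, t2)))
    n_features = F2 * t2
    shapes.append(("flatten", (n_features,)))
    return shapes, n_features

shapes, n_features = eegnet_feature_shape(n_channels=8, n_timepoints=250)
for name, shape in shapes:
    print(f"{name:<24} {shape}")
n_out = 3  # classes in this notebook
print("dense layer parameters:", n_features * n_out + n_out)
```

Under these assumptions, each 8 x 250 trial is reduced to 16 x 7 = 112 features before the final layer, which is one reason EEGNet trains reasonably well on small EEG datasets.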

In [None]:
# Create EEGNet model; input dimensions are taken from the data
model = EEGNetModel(
    n_channels=X_train.shape[1],    # 8
    n_classes=len(class_names),     # 3
    n_timepoints=X_train.shape[2],  # 250
    dropout=0.5,
)

print("Training EEGNet model...")
print("(This may take 1-2 minutes)\n")

# Train model
model.train(X_train, y_train)

print("\n✓ Training complete!")

## 5. Evaluate Model Performance

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns

# Make predictions
y_pred = model.predict(X_test)

# Calculate accuracy (chance level is about 1/3 for three balanced classes)
accuracy = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {accuracy:.2%} (chance: {1 / len(class_names):.2%})\n")

# Classification report
print("Classification Report:")
print(classification_report(y_test, y_pred, target_names=class_names))

# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()
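The confusion matrix and the report above come down to simple counting. Here is a minimal NumPy sketch on made-up labels, using scikit-learn's layout (rows are true classes, columns are predicted classes); `confusion_counts` is a hypothetical helper, and the variables are prefixed with `demo_` so they do not overwrite `accuracy`, which is saved to the registry in the next section.

```python
import numpy as np

def confusion_counts(y_true, y_pred, n_classes):
    """Count (true, predicted) pairs: rows = true class, columns = predicted."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # unbuffered add counts repeated pairs
    return cm

demo_true = np.array([0, 0, 1, 1, 1, 2, 2, 2])
demo_pred = np.array([0, 1, 1, 1, 2, 2, 2, 0])
demo_cm = confusion_counts(demo_true, demo_pred, n_classes=3)

demo_acc = np.trace(demo_cm) / demo_cm.sum()             # correct / total
demo_recall = np.diag(demo_cm) / demo_cm.sum(axis=1)     # per true class (row)
demo_precision = np.diag(demo_cm) / demo_cm.sum(axis=0)  # per predicted class (column)

print(demo_cm)
print(f"accuracy {demo_acc:.3f}, recall {demo_recall.round(3)}, precision {demo_precision.round(3)}")
```

Recall divides each diagonal entry by its row sum and precision by its column sum, matching the corresponding columns of `classification_report`.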

## 6. Save Model to Registry

NeurOS includes a model registry for managing trained models with metadata.

In [None]:
# Initialize registry
registry = ModelRegistry()

# Save model with metadata
metadata = registry.save(
    model,
    name="motor_imagery_eegnet",
    version="1.0.0",
    metrics={
        "accuracy": float(accuracy),
        "n_train_samples": len(X_train),
        "n_test_samples": len(X_test),
    },
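    hyperparameters={
        "n_channels": X_train.shape[1],
        "n_classes": 3,
        "n_timepoints": X_train.shape[2],
        "dropout": 0.5,
    },
    training_info={
        "dataset": "synthetic_motor_imagery",
        # Record the epoch count model.train() actually used; note that
        # train() above is called without an explicit epoch setting.
        "n_epochs": 100,
    },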
    tags=["motor-imagery", "eegnet", "demo"],
)

print(f"✓ Model saved: {metadata.name} v{metadata.version}")
print(f"  Location: {metadata.file_path}")
print(f"  Accuracy: {metadata.metrics['accuracy']:.2%}")
print(f"  Tags: {', '.join(metadata.tags)}")
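As a rough picture of what a model registry does, here is a hypothetical, file-based sketch using only the standard library: each (name, version) pair gets a folder holding the pickled model and a JSON metadata file. It is not NeurOS's `ModelRegistry`, whose storage format this notebook does not show. One design point worth copying: versions are compared numerically, because sorting `"1.10.0"` and `"1.9.0"` as strings puts them in the wrong order.

```python
import json
import pickle
import tempfile
from pathlib import Path

class ToyModelRegistry:
    """Hypothetical minimal registry: <root>/<name>/<version>/{model.pkl, metadata.json}."""

    def __init__(self, root):
        self.root = Path(root)

    def save(self, obj, name, version, metrics=None, tags=None):
        folder = self.root / name / version
        folder.mkdir(parents=True, exist_ok=True)
        with open(folder / "model.pkl", "wb") as f:
            pickle.dump(obj, f)
        meta = {"name": name, "version": version,
                "metrics": metrics or {}, "tags": tags or []}
        (folder / "metadata.json").write_text(json.dumps(meta, indent=2))
        return meta

    def versions(self, name):
        # Assumes plain numeric versions like "1.10.0"; sort as integer tuples
        found = [p.name for p in (self.root / name).iterdir() if p.is_dir()]
        return sorted(found, key=lambda v: tuple(int(x) for x in v.split(".")))

    def load(self, name, version=None):
        version = version or self.versions(name)[-1]  # default: latest version
        with open(self.root / name / version / "model.pkl", "rb") as f:
            return pickle.load(f)

with tempfile.TemporaryDirectory() as tmp:
    toy = ToyModelRegistry(tmp)
    toy.save({"weights": [0.1, 0.2]}, "demo", "1.9.0", metrics={"accuracy": 0.80})
    demo_meta = toy.save({"weights": [0.3, 0.4]}, "demo", "1.10.0", metrics={"accuracy": 0.85})
    demo_versions = toy.versions("demo")
    latest = toy.load("demo")

print(demo_versions)
print(latest)
```

Pickle is convenient for a sketch but should only be loaded from trusted sources, since unpickling can execute arbitrary code.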

## 7. Load Model from Registry

In [None]:
# Load the saved model
loaded_model = registry.load("motor_imagery_eegnet", version="1.0.0")

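# Check that the reloaded model reproduces the original model's predictions
loaded_predictions = loaded_model.predict(X_test)
print("✓ Model loaded successfully")
print(f"  Predictions match original model: {np.array_equal(loaded_predictions, y_pred)}")
print(f"  Predictions on first 5 test trials: {loaded_predictions[:5]}")
print(f"  True labels:                        {y_test[:5]}")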

## 8. Real-Time Classification Simulation

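Simulate real-time motor imagery classification using the mock driver. The mock driver streams generated data rather than labeled motor imagery trials, so this step measures the pipeline's throughput and latency, not its accuracy.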

In [None]:
import asyncio

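# Create pipeline with the trained model.
# The driver's sampling rate and channel count must match the training data (250 Hz, 8 channels).
pipeline = Pipeline(
    driver=MockDriver(sampling_rate=250.0, channels=8),
    model=loaded_model,
    fs=250.0,
    bands={
        "mu": (8, 12),     # Mu rhythm (8-12 Hz) for motor imagery
        "beta": (12, 30),  # Beta band (12-30 Hz)
    },
)

print("Running real-time pipeline for 3 seconds...")
print("(Using mock driver for demonstration)\n")

# Run pipeline. Top-level `await` works in Jupyter/IPython; in a plain
# Python script, use `metrics = asyncio.run(pipeline.run(duration=3.0))` instead.
metrics = await pipeline.run(duration=3.0)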

print("\n✓ Pipeline complete!")
print(f"  Throughput: {metrics['throughput']:.1f} samples/sec")
print(f"  Mean latency: {metrics['mean_latency']*1000:.2f} ms")
print(f"  Total samples: {metrics['samples']}")
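Real-time classifiers typically buffer incoming samples and classify overlapping fixed-length windows. The sketch below shows one way to do that; `sliding_windows` is a hypothetical helper, not how NeurOS's `Pipeline` is implemented, and the 1 s window with a 0.5 s step is an assumption chosen to match the 250-sample trials used for training.

```python
import numpy as np

def sliding_windows(chunks, window=250, step=125):
    """Yield (n_channels, window) arrays from a stream of (n_channels, k) chunks.

    A new window is emitted every `step` samples once `window` samples
    have arrived (here: 1 s windows with 50% overlap at 250 Hz).
    """
    buffer = None
    for chunk in chunks:
        buffer = chunk if buffer is None else np.concatenate([buffer, chunk], axis=1)
        while buffer.shape[1] >= window:
            yield buffer[:, :window]
            buffer = buffer[:, step:]  # drop the oldest `step` samples

# Simulated stream: 3 s of 8-channel data at 250 Hz, delivered in 25-sample chunks
stream_rng = np.random.default_rng(0)
stream = (stream_rng.standard_normal((8, 25)) for _ in range(30))
windows = list(sliding_windows(stream))
print(len(windows), windows[0].shape)
```

Each window has the same shape as a training trial, so it could be classified with `loaded_model.predict(window[np.newaxis])`. The step size trades update rate against compute: a smaller step gives more frequent predictions but more model calls per second.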

## 9. Next Steps

Now that you've built a motor imagery BCI:

1. **Try with real hardware:** Replace `MockDriver` with `BrainFlowDriver` for OpenBCI, Emotiv, etc.
2. **Experiment with models:** Try `TransformerModel`, `CNNModel`, or `RandomForestModel`
3. **Tune hyperparameters:** Adjust learning rate, dropout, number of epochs
4. **Real-time feedback:** Add visual/audio feedback based on predictions
5. **Multi-modal fusion:** Combine EEG with EMG, video, or other modalities

See other notebooks:
- `02_multimodal_fusion.ipynb` - Combine multiple data sources
- `03_p300_speller.ipynb` - Build a brain-controlled keyboard
- `04_real_hardware_setup.ipynb` - Connect to physical BCI devices

In [None]:
# List all models in registry
all_models = registry.list_models()
print(f"\nModels in registry: {len(all_models)}")
for m in all_models:
    acc = m.metrics.get('accuracy', 0)
    print(f"  - {m.name} v{m.version}: {acc:.2%} accuracy")