# EEG Motor Imagery Classification using CNNs

## 🧠 Complete Guide to EEG Signal Processing and Classification

### 🎯 What You'll Learn:
1. **EEG Signal Processing**: Understanding temporal brain signals
2. **PhysioNet Motor Imagery Dataset**: Loading and preprocessing
3. **Channel Selection**: Identifying optimal EEG electrodes
4. **CNN for Temporal Data**: 1D and 2D convolutions for EEG
5. **Motor Imagery Classification**: Rest vs both fists vs both feet (PhysioNet runs 6, 10, and 14)
6. **Advanced Techniques**: Spectral features, filtering, artifact removal

---

## 🧠 Understanding EEG and Motor Imagery

### What is EEG?
**Electroencephalography (EEG)** measures electrical activity of the brain using electrodes placed on the scalp.

### Motor Imagery
**Motor Imagery** is the mental rehearsal of motor actions without actual movement. When you imagine moving your hand, specific brain regions activate, creating detectable EEG patterns.

### Key EEG Concepts:
- **Channels**: 64 electrodes positioned according to 10-10 system
- **Sampling Rate**: 160 Hz (160 samples per second)
- **Frequency Bands**: 
  - Delta (0.5-4 Hz): Deep sleep
  - Theta (4-8 Hz): Drowsiness
  - Alpha (8-13 Hz): Relaxed awareness
  - Beta (13-30 Hz): Active thinking
  - Gamma (30-100 Hz): High-level cognitive functions
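
The band boundaries above can be turned into numbers with a Welch power spectral density estimate. The sketch below is a minimal illustration on a synthetic signal, not part of the original pipeline: `band_powers`, `BANDS`, and the test signal are assumed names. The gamma band is capped at 80 Hz because a 160 Hz sampling rate cannot represent frequencies above the Nyquist limit (half the sampling rate).

```python
import numpy as np
from scipy.signal import welch

FS = 160  # sampling rate in Hz, as listed above

# Band edges in Hz; gamma is capped at the Nyquist frequency (FS / 2)
BANDS = {
    "delta": (0.5, 4),
    "theta": (4, 8),
    "alpha": (8, 13),
    "beta": (13, 30),
    "gamma": (30, 80),
}

def band_powers(x, fs=FS):
    """Integrate the Welch PSD of a 1-D signal over each band."""
    freqs, psd = welch(x, fs=fs, nperseg=2 * fs)  # 0.5 Hz frequency resolution
    df = freqs[1] - freqs[0]
    return {
        name: float(psd[(freqs >= lo) & (freqs < hi)].sum() * df)
        for name, (lo, hi) in BANDS.items()
    }

# Synthetic signal: a 10 Hz oscillation plus weak white noise
rng = np.random.default_rng(0)
t = np.arange(0, 10, 1 / FS)
x = np.sin(2 * np.pi * 10 * t) + 0.1 * rng.standard_normal(t.size)

powers = band_powers(x)
dominant = max(powers, key=powers.get)
print(f"Dominant band: {dominant}")
```

On real EEG the same function could be applied per channel and per epoch to obtain spectral features.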

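### Motor Imagery Tasks:
The PhysioNet dataset records four imagined movements, split across two groups of runs:
1. **Left Fist**: Imagining opening and closing the left fist (runs 4, 8, 12)
2. **Right Fist**: Imagining opening and closing the right fist (runs 4, 8, 12)
3. **Both Fists**: Imagining opening and closing both fists (runs 6, 10, 14)
4. **Both Feet**: Imagining opening and closing both feet (runs 6, 10, 14)

The dataset contains no tongue-imagery task.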

---

## 🔧 Why CNNs for EEG?

### Traditional Approach vs CNN:
- **Traditional**: Manual feature extraction → Classical ML
- **CNN**: Automatic feature learning from raw signals

### CNN Advantages for EEG:
1. **Temporal Patterns**: 1D convolutions capture time-series patterns
2. **Spatial Patterns**: 2D convolutions capture electrode relationships
3. **Automatic Features**: Less manual feature engineering (basic filtering and epoching are still applied)
4. **Hierarchical Learning**: Low-level → High-level patterns
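
To make the "temporal patterns" point concrete, the sketch below treats a single 1D convolution kernel as a frequency-selective filter. It is an illustration only: the kernel is hand-built here (a windowed 10 Hz sinusoid, 64 taps) as a stand-in for weights a network would learn, and `response_energy` is a hypothetical helper.

```python
import numpy as np

FS = 160  # sampling rate in Hz
t = np.arange(0, 2, 1 / FS)

# Hand-built stand-in for a learned temporal kernel: 64 taps tuned to 10 Hz
taps = np.arange(64) / FS
kernel = np.sin(2 * np.pi * 10 * taps) * np.hanning(64)

def response_energy(signal_1d, kernel):
    """Mean squared output of the kernel slid along the signal."""
    out = np.convolve(signal_1d, kernel, mode="valid")
    return float(np.mean(out ** 2))

mu_like = np.sin(2 * np.pi * 10 * t)    # matches the kernel frequency
beta_like = np.sin(2 * np.pi * 25 * t)  # outside the kernel's pass band

e_mu = response_energy(mu_like, kernel)
e_beta = response_energy(beta_like, kernel)
print(f"10 Hz response energy: {e_mu:.2f}")
print(f"25 Hz response energy: {e_beta:.6f}")
```

The kernel responds strongly to the rhythm it matches and weakly to the other one, which is the mechanism a `Conv1D` layer relies on when it learns its own kernels from data.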

---

## 🚀 Let's Begin!

In [1]:
# Cell 1: Import Essential Libraries for EEG Processing

print("🧠 Setting up EEG Motor Imagery Classification Environment")
print("=" * 60)

# Core libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal
from scipy.stats import zscore
import warnings
warnings.filterwarnings('ignore')

# MNE for EEG processing
import mne
from mne import Epochs, pick_types
from mne.channels import make_standard_montage
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf

# Deep Learning
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, callbacks
from tensorflow.keras.utils import to_categorical

# Machine Learning
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest, f_classif

# Visualization
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

# Utilities
import os
import datetime
import pickle
from tqdm import tqdm

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Configure matplotlib
plt.style.use('default')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10

# MNE configuration
mne.set_log_level('WARNING')  # Reduce MNE verbosity

print("✅ All libraries imported successfully!")
print(f"📊 TensorFlow version: {tf.__version__}")
print(f"🧠 MNE version: {mne.__version__}")
print(f"🔢 NumPy version: {np.__version__}")
print("🚀 Ready for EEG signal processing!")


In [None]:
# Cell 2: Load PhysioNet Motor Imagery Dataset

print("📥 Loading PhysioNet EEG Motor Imagery Dataset")
print("=" * 50)

def load_physionet_data(subject_ids=(1, 2, 3), runs=(6, 10, 14), verbose=True):
    """
    Load PhysioNet EEG Motor Imagery data using MNE.
    
    Parameters:
    -----------
    subject_ids : list
        List of subject IDs to load (1-109)
    runs : list  
        List of runs to load. Annotations are T0 (rest), T1, and T2:
        - Runs 6, 10, 14: imagined both fists (T1) vs both feet (T2)
        - Runs 4, 8, 12: imagined left fist (T1) vs right fist (T2)
    
    Returns:
    --------
    raw_data : list
        List of MNE Raw objects
    """
    
    raw_files = []
    
    for subject_id in subject_ids:
        if verbose:
            print(f"\n👤 Loading Subject {subject_id}...")
        
        subject_runs = []
        
        for run in runs:
            try:
                # Download data files for this subject and run
                files = eegbci.load_data(subject_id, runs=[run], update_path=False)
                
                # Load the EDF file
                raw = read_raw_edf(files[0], preload=True, stim_channel='auto')
                
                # Set standard electrode montage
                eegbci.standardize(raw)  # Convert to standard channel names
                montage = make_standard_montage('standard_1005')
                raw.set_montage(montage, match_case=False)
                
                # Add subject and run info
                raw.info['subject_info'] = {'id': subject_id, 'run': run}
                
                subject_runs.append(raw)
                
                if verbose:
                    print(f"  ✅ Run {run}: {len(raw.times)} samples, {len(raw.ch_names)} channels")
                    
            except Exception as e:
                if verbose:
                    print(f"  ❌ Run {run}: Failed to load - {e}")
                continue
        
        if subject_runs:
            # Concatenate runs for this subject
            raw_concat = concatenate_raws(subject_runs)
            raw_files.append(raw_concat)
    
    return raw_files

# Load data for first 3 subjects (you can increase this later)
print("🔄 Starting data download...")
print("Note: First download may take several minutes")

# Motor imagery runs used here (annotations: T0 = rest, T1, T2):
# Runs 6, 10, 14: imagined both fists (T1) vs both feet (T2)
# Left vs right fist imagery is recorded in runs 4, 8, 12 instead
raw_data = load_physionet_data(subject_ids=[1, 2, 3], runs=[6, 10, 14])

print(f"\n✅ Successfully loaded data for {len(raw_data)} subjects")

# Display dataset information
if raw_data:
    sample_raw = raw_data[0]
    print(f"\n📊 Dataset Information:")
    print(f"Sampling frequency: {sample_raw.info['sfreq']} Hz")
    print(f"Number of channels: {len(sample_raw.ch_names)}")
    print(f"Channel types: {set(sample_raw.get_channel_types())}")
    print(f"Duration: {sample_raw.times[-1]:.1f} seconds")
    
    # Show channel names
    eeg_channels = mne.pick_types(sample_raw.info, eeg=True)
    eeg_ch_names = [sample_raw.ch_names[i] for i in eeg_channels]
    print(f"\n🧠 EEG Channels ({len(eeg_ch_names)}): {eeg_ch_names[:10]}...")
    
    print("\n🎯 Motor Imagery Task Mapping (runs 6, 10, 14):")
    print("T0: Rest")
    print("T1: Both Fists Imagery")
    print("T2: Both Feet Imagery")

else:
    print("❌ No data loaded successfully. Check your internet connection.")
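
The run numbers decide what the `T1` and `T2` annotations mean, so it helps to keep that mapping explicit. The lookup below is a small sketch based on the PhysioNet protocol description; `EEGBCI_RUNS` and `annotation_labels` are hypothetical names, and the label strings are chosen for this example.

```python
# Run groups as described in the PhysioNet EEG Motor Movement/Imagery protocol
EEGBCI_RUNS = {
    "baseline_eyes_open": [1],
    "baseline_eyes_closed": [2],
    "execution_left_vs_right_fist": [3, 7, 11],
    "imagery_left_vs_right_fist": [4, 8, 12],
    "execution_fists_vs_feet": [5, 9, 13],
    "imagery_fists_vs_feet": [6, 10, 14],
}

def annotation_labels(run):
    """Map the T0/T1/T2 annotations of a run to readable labels."""
    if run in (1, 2):
        return {"T0": "rest"}
    if run in (3, 4, 7, 8, 11, 12):
        return {"T0": "rest", "T1": "left_fist", "T2": "right_fist"}
    if run in (5, 6, 9, 10, 13, 14):
        return {"T0": "rest", "T1": "both_fists", "T2": "both_feet"}
    raise ValueError(f"Run {run} is outside the valid range 1-14")

for run in EEGBCI_RUNS["imagery_fists_vs_feet"]:
    print(run, annotation_labels(run))
```

Passing runs 4, 8, and 12 to the loader instead would give a left-versus-right classification problem.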

In [None]:
# Cell 3: EEG Signal Preprocessing and Visualization

print("🔧 EEG Signal Preprocessing Pipeline")
print("=" * 40)

def preprocess_eeg_data(raw_data, l_freq=7., h_freq=30., notch_freq=50., 
                       tmin=-1., tmax=4., baseline=(None, 0), verbose=True):
    """
    Comprehensive EEG preprocessing pipeline.
    
    Parameters:
    -----------
    raw_data : list of mne.Raw
        Raw EEG data from multiple subjects
    l_freq : float
        Lower cutoff of the band-pass filter, i.e. the high-pass edge (Hz)
    h_freq : float
        Upper cutoff of the band-pass filter, i.e. the low-pass edge (Hz)
    notch_freq : float
        Notch filter frequency for power line noise (Hz)
    tmin, tmax : float
        Time window around events (seconds)
    baseline : tuple
        Baseline correction period
        
    Returns:
    --------
    epochs_list : list
        Preprocessed epochs for each subject
    """
    
    epochs_list = []
    
    for i, raw in enumerate(raw_data):
        if verbose:
            print(f"\n🔄 Processing Subject {i+1}...")
        
        # Make a copy to avoid modifying original data
        raw_copy = raw.copy()
        
        # 1. Filter the data
        if verbose:
            print(f"  📶 Applying filters: {l_freq}-{h_freq} Hz bandpass + {notch_freq} Hz notch")
        
        # Bandpass filter (remove low-frequency drifts and high-frequency noise)
        raw_copy.filter(l_freq=l_freq, h_freq=h_freq, method='iir', verbose=False)
        
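        # Notch filter (remove power line interference). Largely redundant after the
        # 30 Hz upper cutoff above; the recordings were made in the US (60 Hz mains).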
        raw_copy.notch_filter(freqs=notch_freq, verbose=False)
        
        # 2. Extract events
        try:
            events, event_id = mne.events_from_annotations(raw_copy)
            if verbose:
                print(f"  🎯 Found {len(events)} events: {event_id}")
        except Exception as e:
            if verbose:
                print(f"  ❌ Could not extract events: {e}")
            continue
        
        # 3. Select only EEG channels
        picks = mne.pick_types(raw_copy.info, eeg=True, exclude='bads')
        
        # 4. Create epochs around events
        try:
            epochs = Epochs(raw_copy, events, event_id, tmin=tmin, tmax=tmax,
                          picks=picks, baseline=baseline, preload=True, verbose=False)
            
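            # 5. Drop epochs flagged as bad (no `reject` thresholds are passed to
            #    Epochs, so no amplitude-based artifact rejection happens here)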
            epochs.drop_bad()
            
            if verbose:
                print(f"  ✅ Created {len(epochs)} epochs, shape: {epochs.get_data().shape}")
                class_counts = {name: int(np.sum(epochs.events[:, 2] == code))
                                for name, code in epochs.event_id.items()}
                print(f"  📊 Events per class: {class_counts}")
            
            epochs_list.append(epochs)
            
        except Exception as e:
            if verbose:
                print(f"  ❌ Could not create epochs: {e}")
            continue
    
    return epochs_list

# Apply preprocessing
if raw_data:
    print("🔄 Starting preprocessing pipeline...")
    print("\n🎛️ Preprocessing Parameters:")
    print("• Bandpass filter: 7-30 Hz (attenuates slow drifts and high-frequency noise, keeps mu/beta motor rhythms)")
    print("• Notch filter: 50 Hz (removes power line noise)")
    print("• Epoch window: -1 to +4 seconds around event")
    print("• Baseline: -1 to 0 seconds (pre-stimulus period)")
    
    epochs_list = preprocess_eeg_data(raw_data, verbose=True)
    
    print(f"\n✅ Preprocessing completed for {len(epochs_list)} subjects")
    
    if epochs_list:
        # Display preprocessing results
        total_epochs = sum(len(epochs) for epochs in epochs_list)
        sample_epochs = epochs_list[0]
        
        print(f"\n📈 Preprocessing Summary:")
        print(f"Total epochs across all subjects: {total_epochs}")
        print(f"Epoch shape: {sample_epochs.get_data().shape}")
        print(f"Time points per epoch: {sample_epochs.get_data().shape[-1]}")
        print(f"Sampling rate: {sample_epochs.info['sfreq']} Hz")
        print(f"Time window: {sample_epochs.tmin} to {sample_epochs.tmax} seconds")
        
else:
    print("❌ No raw data available for preprocessing")
    epochs_list = []
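
The effect of the 7-30 Hz band-pass step can be checked outside MNE. The sketch below uses a SciPy Butterworth filter as a stand-in (the MNE call above uses its own IIR defaults, which may differ); `bandpass`, `rms`, and the three test tones are assumptions made for this illustration. Note that `l_freq` is the lower band edge and `h_freq` the upper one.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

FS = 160.0  # sampling rate in Hz

def bandpass(x, l_freq=7.0, h_freq=30.0, fs=FS, order=4):
    """Zero-phase Butterworth band-pass along the last axis."""
    sos = butter(order, [l_freq, h_freq], btype="bandpass", fs=fs, output="sos")
    return sosfiltfilt(sos, x, axis=-1)

def rms(x):
    return float(np.sqrt(np.mean(x ** 2)))

t = np.arange(0, 5, 1 / FS)
tones = {
    "drift_1hz": np.sin(2 * np.pi * 1 * t),   # below the pass band
    "mu_10hz": np.sin(2 * np.pi * 10 * t),    # inside the pass band
    "line_50hz": np.sin(2 * np.pi * 50 * t),  # above the pass band
}

# Ratio of output to input amplitude for each tone
gains = {name: rms(bandpass(sig)) / rms(sig) for name, sig in tones.items()}
for name, gain in gains.items():
    print(f"{name}: gain = {gain:.3f}")
```

The 50 Hz tone is already strongly attenuated by the 30 Hz upper edge, which is why a separate notch filter adds little after this band-pass.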

In [None]:
# Cell 4: EEG Data Visualization and Analysis

print("📊 EEG Signal Visualization and Analysis")
print("=" * 40)

def visualize_eeg_data(epochs_list, max_subjects=2):
    """
    Create comprehensive visualizations of EEG data.
    """
    
    if not epochs_list:
        print("❌ No epochs data available for visualization")
        return
    
    # Use first subject for detailed analysis
    epochs = epochs_list[0]
    data = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)
    
    print(f"\n📊 Analyzing Subject 1 Data:")
    print(f"Data shape: {data.shape}")
    print(f"Event types: {list(epochs.event_id.keys())}")
    
    # 1. Plot average ERPs for each class
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Event-Related Potentials (ERPs) by Motor Imagery Task', fontsize=16)
    
    # Get some key channels for visualization
    key_channels = ['C3', 'C4', 'Cz', 'FC1', 'FC2']  # Motor cortex channels
    available_channels = [ch for ch in key_channels if ch in epochs.ch_names]
    
    if not available_channels:
        available_channels = epochs.ch_names[:5]  # Use first 5 channels if standard names not found
    
    print(f"\n🧠 Analyzing key motor cortex channels: {available_channels}")
    
    colors = ['blue', 'red', 'green', 'orange']
    event_names = list(epochs.event_id.keys())
    
    for i, ch_name in enumerate(available_channels[:4]):
        ax = axes[i//2, i%2]
        ch_idx = epochs.ch_names.index(ch_name)
        
        for j, (event_name, event_code) in enumerate(epochs.event_id.items()):
            # Get epochs for this event type
            event_epochs = epochs[event_name]
            if len(event_epochs) > 0:
                # Average across epochs
                # MNE stores EEG in volts; scale to µV to match the axis label
                avg_signal = event_epochs.get_data()[:, ch_idx, :].mean(axis=0) * 1e6
                times = epochs.times
                
                ax.plot(times, avg_signal, color=colors[j % len(colors)], 
                       label=event_name, linewidth=2)
        
        ax.set_title(f'Channel {ch_name}')
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude (µV)')
        # Draw the onset marker before the legend so its label is included
        ax.axvline(x=0, color='black', linestyle='--', alpha=0.5, label='Event onset')
        ax.legend()
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # 2. Topographic maps showing spatial distribution
    if len(epochs.ch_names) > 10:  # Only if we have enough channels
        print("\n🗺️  Creating topographic maps...")
        
        fig, axes = plt.subplots(1, len(event_names), figsize=(4*len(event_names), 4))
        if len(event_names) == 1:
            axes = [axes]
        
        for i, (event_name, event_code) in enumerate(epochs.event_id.items()):
            event_epochs = epochs[event_name]
            if len(event_epochs) > 0:
                # Average over time window 0.5-2.0 seconds (motor imagery period)
                time_mask = (epochs.times >= 0.5) & (epochs.times <= 2.0)
                avg_topo = event_epochs.get_data()[:, :, time_mask].mean(axis=(0, 2))
                
                # Create topographic plot
                im, _ = mne.viz.plot_topomap(avg_topo, epochs.info, axes=axes[i], 
                                            show=False, cmap='RdBu_r')
                axes[i].set_title(f'{event_name}\n(0.5-2.0s avg)')
        
        plt.tight_layout()
        plt.show()
    
    # 3. Power Spectral Density Analysis
    print("\n⚡ Computing Power Spectral Density...")
    
    plt.figure(figsize=(12, 8))
    
    for i, (event_name, event_code) in enumerate(epochs.event_id.items()):
        event_epochs = epochs[event_name]
        if len(event_epochs) > 0:
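            # Compute PSD using Welch's method. Epochs.compute_psd replaces
            # mne.time_frequency.psd_welch, which recent MNE releases no longer provide.
            spectrum = event_epochs.compute_psd(method='welch', fmin=1, fmax=40,
                                                n_fft=256, verbose=False)
            psds, freqs = spectrum.get_data(return_freqs=True)
            # Average across epochs and channels; convert V²/Hz to µV²/Hz
            avg_psd = psds.mean(axis=(0, 1)) * 1e12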
            
            plt.semilogy(freqs, avg_psd, color=colors[i % len(colors)], 
                        label=event_name, linewidth=2)
    
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power Spectral Density (µV²/Hz)')
    plt.title('Average Power Spectral Density by Motor Imagery Task')
    plt.grid(True, alpha=0.3)
    
    # Highlight important frequency bands (drawn before the legend so their labels appear)
    plt.axvspan(8, 13, alpha=0.2, color='yellow', label='Alpha (8-13 Hz)')
    plt.axvspan(13, 30, alpha=0.2, color='cyan', label='Beta (13-30 Hz)')
    plt.legend()
    
    plt.tight_layout()
    plt.show()
    
    return data

# Create visualizations
if epochs_list:
    print("🎨 Creating EEG visualizations...")
    print("\n📈 This will show:")
    print("1. Event-Related Potentials (ERPs) for motor cortex channels")
    print("2. Topographic maps showing spatial activation patterns")
    print("3. Power spectral density analysis")
    
    epoch_data = visualize_eeg_data(epochs_list)
    
    print("\n✅ Visualization completed!")
    print("\n🧠 Key Observations to Look For:")
    print("• Different ERP patterns between left/right hand imagery")
    print("• Lateralized activation (left motor cortex for right hand, vice versa)")
    print("• Alpha/beta band power differences between tasks")
    print("• Event-related desynchronization (ERD) in motor frequencies")
    
else:
    print("❌ No epoch data available for visualization")
    epoch_data = None
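
Event-related desynchronization (ERD) is usually reported as the percentage change in band power relative to a baseline: `ERD% = (A - R) / R * 100`, where `R` is the baseline power and `A` the power during the task. The sketch below applies that formula to synthetic epochs whose 10 Hz amplitude halves after event onset. It assumes the data is already band-pass filtered; `erd_percent` and the window choices are illustrative, not taken from the notebook.

```python
import numpy as np

FS = 160
times = np.arange(-FS, 4 * FS + 1) / FS  # -1 s to +4 s, 801 samples

# Synthetic epochs: amplitude 1.0 before onset, 0.5 after (power drops to 25%)
rng = np.random.default_rng(1)
n_epochs = 20
amplitude = np.where(times < 0, 1.0, 0.5)
phases = rng.uniform(0, 2 * np.pi, size=(n_epochs, 1))
epochs_1ch = amplitude * np.sin(2 * np.pi * 10 * times + phases)

def erd_percent(data, times, baseline=(-1.0, 0.0), active=(0.5, 2.0)):
    """Relative band-power change (%) between baseline and active windows."""
    power = (data ** 2).mean(axis=0)  # average power across trials
    ref = power[(times >= baseline[0]) & (times < baseline[1])].mean()
    act = power[(times >= active[0]) & (times <= active[1])].mean()
    return 100.0 * (act - ref) / ref

erd = erd_percent(epochs_1ch, times)
print(f"ERD: {erd:.1f}%")
```

A negative value means power dropped during the task (desynchronization); a positive value would indicate synchronization.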

In [None]:
# Cell 5: EEG Channel Selection Techniques

print("🎯 EEG Channel Selection for Motor Imagery")
print("=" * 45)

def analyze_channel_importance(epochs_list, method='variance', top_k=16):
    """
    Analyze and select the most important EEG channels for motor imagery classification.
    
    Parameters:
    -----------
    epochs_list : list
        List of MNE Epochs objects
    method : str
        Channel selection method ('variance', 'motor_cortex', 'statistical')
    top_k : int
        Number of top channels to select
        
    Returns:
    --------
    selected_channels : list
        Names of selected channels
    channel_scores : dict
        Importance scores for each channel
    """
    
    if not epochs_list:
        return [], {}
    
    # Combine data from all subjects
    all_data = []
    all_labels = []
    
    for epochs in epochs_list:
        data = epochs.get_data()  # (n_epochs, n_channels, n_times)
        labels = epochs.events[:, 2]  # Event codes
        
        all_data.append(data)
        all_labels.append(labels)
    
    # Concatenate all subjects
    X = np.concatenate(all_data, axis=0)  # (total_epochs, n_channels, n_times)
    y = np.concatenate(all_labels, axis=0)  # (total_epochs,)
    
    print(f"\n📊 Combined dataset shape: {X.shape}")
    print(f"📊 Labels shape: {y.shape}")
    print(f"📊 Unique classes: {np.unique(y)}")
    
    channel_names = epochs_list[0].ch_names
    channel_scores = {}
    
    if method == 'motor_cortex':
        print("\n🧠 Method: Motor Cortex Channel Selection")
        print("Selecting channels over motor and sensorimotor areas...")
        
        # Define sensorimotor channels (names follow the 10-10 system)
        motor_channels = [
            'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6',  # Fronto-central
            'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6',        # Central
            'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6'   # Centro-parietal
        ]
        
        # Find available motor channels
        available_motor = [ch for ch in motor_channels if ch in channel_names]
        
        # If not enough motor channels, add nearby channels
        if len(available_motor) < top_k:
            additional_channels = ['F3', 'F4', 'P3', 'P4', 'T7', 'T8', 'Fz', 'Pz']
            for ch in additional_channels:
                if ch in channel_names and ch not in available_motor:
                    available_motor.append(ch)
                    if len(available_motor) >= top_k:
                        break
        
        selected_channels = available_motor[:top_k]
        
        # Assign scores based on motor relevance
        for i, ch in enumerate(channel_names):
            if ch in selected_channels:
                channel_scores[ch] = 1.0 - (selected_channels.index(ch) / len(selected_channels))
            else:
                channel_scores[ch] = 0.0
                
    elif method == 'variance':
        print("\n📈 Method: Variance-Based Channel Selection")
        print("Selecting channels with highest signal variance...")
        
        # Calculate variance for each channel across all epochs and time
        channel_variances = np.var(X, axis=(0, 2))  # Variance across epochs and time
        
        # Rank channels by variance
        channel_ranking = np.argsort(channel_variances)[::-1]  # Descending order
        selected_indices = channel_ranking[:top_k]
        selected_channels = [channel_names[i] for i in selected_indices]
        
        # Store scores
        for i, ch in enumerate(channel_names):
            channel_scores[ch] = channel_variances[i]
            
    elif method == 'statistical':
        print("\n📊 Method: Statistical Channel Selection (F-score)")
        print("Selecting channels that best discriminate between classes...")
        
        # Flatten temporal dimension for statistical analysis
        X_flat = X.reshape(X.shape[0], -1)  # (n_epochs, n_channels * n_times)
        
        # Calculate F-score for each feature (channel x time)
        f_scores, _ = f_classif(X_flat, y)
        
        # Reshape back to (n_channels, n_times) and average over time
        f_scores_reshaped = f_scores.reshape(X.shape[1], X.shape[2])
        channel_f_scores = np.mean(f_scores_reshaped, axis=1)
        
        # Rank channels by F-score
        channel_ranking = np.argsort(channel_f_scores)[::-1]
        selected_indices = channel_ranking[:top_k]
        selected_channels = [channel_names[i] for i in selected_indices]
        
        # Store scores
        for i, ch in enumerate(channel_names):
            channel_scores[ch] = channel_f_scores[i]
    
    return selected_channels, channel_scores

def visualize_channel_selection(epochs_list, selected_channels, channel_scores, method):
    """
    Visualize selected channels and their importance scores.
    """
    
    if not epochs_list or not selected_channels:
        return
    
    epochs = epochs_list[0]
    
    # 1. Plot channel importance scores
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Bar plot of top channels
    top_channels = selected_channels[:16]
    top_scores = [channel_scores[ch] for ch in top_channels]
    
    ax1.bar(range(len(top_channels)), top_scores, color='steelblue', alpha=0.7)
    ax1.set_xlabel('Channel Index')
    ax1.set_ylabel('Importance Score')
    ax1.set_title(f'Top {len(top_channels)} Channels - {method.title()} Method')
    ax1.set_xticks(range(len(top_channels)))
    ax1.set_xticklabels(top_channels, rotation=45)
    ax1.grid(True, alpha=0.3)
    
    # 2. Topographic map of channel importance
    if len(epochs.ch_names) > 10:
        try:
            # Create importance vector for all channels
            importance_vector = np.array([channel_scores.get(ch, 0) for ch in epochs.ch_names])
            
            # Normalize for better visualization
            if np.max(importance_vector) > 0:
                importance_vector = importance_vector / np.max(importance_vector)
            
            # Plot topographic map
            # Recent MNE releases take color limits as vlim=(min, max) instead of vmin/vmax
            im, _ = mne.viz.plot_topomap(importance_vector, epochs.info, axes=ax2, 
                                        show=False, cmap='Reds', vlim=(0, 1))
            ax2.set_title(f'Channel Importance Map\n({method.title()} Method)')
            
            # Add colorbar
            cbar = plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
            cbar.set_label('Importance Score')
            
        except Exception as e:
            ax2.text(0.5, 0.5, f'Topographic plot not available\n{str(e)}', 
                    transform=ax2.transAxes, ha='center', va='center')
            ax2.set_title('Topographic Map (Not Available)')
    else:
        ax2.text(0.5, 0.5, 'Insufficient channels\nfor topographic map', 
                transform=ax2.transAxes, ha='center', va='center')
        ax2.set_title('Topographic Map (Not Available)')
    
    plt.tight_layout()
    plt.show()

# Apply different channel selection methods
if epochs_list:
    print("🔍 Applying multiple channel selection methods...")
    
    methods = ['motor_cortex', 'variance', 'statistical']
    selection_results = {}
    
    for method in methods:
        print(f"\n{'-'*50}")
        selected_channels, channel_scores = analyze_channel_importance(epochs_list, method=method, top_k=16)
        
        if selected_channels:
            selection_results[method] = {
                'channels': selected_channels,
                'scores': channel_scores
            }
            
            print(f"\n✅ Selected {len(selected_channels)} channels using {method} method:")
            print(f"Top 10: {selected_channels[:10]}")
            
            # Visualize results
            visualize_channel_selection(epochs_list, selected_channels, channel_scores, method)
    
    # Compare methods
    if len(selection_results) > 1:
        print(f"\n🔍 Comparing Channel Selection Methods:")
        print(f"{'Method':<15} {'Top 5 Channels':<50}")
        print(f"{'-'*65}")
        
        for method, results in selection_results.items():
            top5 = ', '.join(results['channels'][:5])
            print(f"{method:<15} {top5:<50}")
        
        # Find common channels across methods
        all_channels = [set(results['channels'][:10]) for results in selection_results.values()]
        common_channels = set.intersection(*all_channels)
        
        print(f"\n🎯 Channels selected by ALL methods: {sorted(common_channels)}")
        print(f"\n💡 Recommendation: Use motor_cortex method for interpretability")
        print(f"   or statistical method for class discrimination. The scores here use")
        print(f"   all epochs, so rerun the selection on training data only before evaluation.")
    
    # Store best selection for later use
    best_method = 'statistical' if 'statistical' in selection_results else list(selection_results.keys())[0]
    best_channels = selection_results[best_method]['channels'][:16]
    
    print(f"\n🏆 Using {best_method} method with {len(best_channels)} channels for CNN training")
    
else:
    print("❌ No epoch data available for channel selection")
    selection_results = {}
    best_channels = []
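
Data-driven channel scores computed on every epoch also see the epochs that later end up in the test set. One way to avoid that leak is to rank channels on the training split only. The sketch below shows the idea on synthetic data. It swaps in a different score from the cell above: a two-class Fisher score on per-epoch log-variance, rather than `f_classif` on raw samples. `rank_channels_train_only` is a hypothetical helper.

```python
import numpy as np

def rank_channels_train_only(X, y, train_idx, top_k=2):
    """Rank channels by a Fisher score of log-variance, using training epochs only.

    X: (n_epochs, n_channels, n_times), y: (n_epochs,) with two classes.
    """
    feats = np.log(X[train_idx].var(axis=2))  # (n_train, n_channels)
    labels = y[train_idx]
    class_a, class_b = np.unique(labels)[:2]
    a = feats[labels == class_a]
    b = feats[labels == class_b]
    # Squared mean difference divided by the summed within-class variances
    scores = (a.mean(axis=0) - b.mean(axis=0)) ** 2 / (a.var(axis=0) + b.var(axis=0) + 1e-12)
    order = np.argsort(scores)[::-1]
    return order[:top_k], scores

# Synthetic data: only channel 2 differs between the classes (larger amplitude)
rng = np.random.default_rng(0)
n_epochs, n_channels, n_times = 40, 4, 160
y = np.repeat([0, 1], n_epochs // 2)
X = rng.standard_normal((n_epochs, n_channels, n_times))
X[y == 1, 2, :] *= 2.0

perm = rng.permutation(n_epochs)
train_idx, test_idx = perm[:30], perm[30:]

top, scores = rank_channels_train_only(X, y, train_idx)
print(f"Top channels (indices): {top.tolist()}")
```

The held-out indices in `test_idx` never influence the ranking, so any accuracy measured on them is not biased by the selection step.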

In [None]:
# Cell 6: Prepare Data for CNN Training

print("🔧 Preparing EEG Data for CNN Training")
print("=" * 40)

def prepare_cnn_data(epochs_list, selected_channels=None, test_size=0.2, val_size=0.2):
    """
    Prepare EEG data for CNN training.
    
    Parameters:
    -----------
    epochs_list : list
        List of MNE Epochs objects
    selected_channels : list
        Names of selected channels to use
    test_size : float
        Fraction of data to use for testing
    val_size : float
        Fraction of training data to use for validation
        
    Returns:
    --------
    X_train, X_val, X_test : numpy arrays
        Training, validation, and test data
    y_train, y_val, y_test : numpy arrays
        Training, validation, and test labels
    class_names : list
        Names of the classes
    """
    
    if not epochs_list:
        return None, None, None, None, None, None, None
    
    # Combine data from all subjects
    all_data = []
    all_labels = []
    all_subjects = []
    
    for subject_idx, epochs in enumerate(epochs_list):
        data = epochs.get_data()  # (n_epochs, n_channels, n_times)
        labels = epochs.events[:, 2]  # Event codes
        
        # Select specific channels if provided
        if selected_channels:
            channel_indices = [epochs.ch_names.index(ch) for ch in selected_channels 
                             if ch in epochs.ch_names]
            if channel_indices:
                data = data[:, channel_indices, :]
                print(f"\n📡 Subject {subject_idx+1}: Using {len(channel_indices)} selected channels")
            else:
                print(f"\n⚠️  Subject {subject_idx+1}: No selected channels found, using all channels")
        
        all_data.append(data)
        all_labels.append(labels)
        all_subjects.extend([subject_idx] * len(data))
    
    # Concatenate all subjects
    X = np.concatenate(all_data, axis=0)  # (total_epochs, n_channels, n_times)
    y = np.concatenate(all_labels, axis=0)  # (total_epochs,)
    subjects = np.array(all_subjects)  # (total_epochs,)
    
    print(f"\n📊 Combined Dataset Information:")
    print(f"Total epochs: {X.shape[0]}")
    print(f"Channels: {X.shape[1]}")
    print(f"Time points: {X.shape[2]}")
    print(f"Data shape: {X.shape}")
    print(f"Unique classes: {np.unique(y)}")
    
    # Create class names mapping
    event_id = epochs_list[0].event_id
    class_mapping = {v: k for k, v in event_id.items()}
    class_names = [class_mapping[label] for label in sorted(np.unique(y))]
    
    print(f"\n🎯 Class Mapping:")
    for i, (label, name) in enumerate(zip(sorted(np.unique(y)), class_names)):
        count = np.sum(y == label)
        print(f"  {label} → {name}: {count} epochs ({count/len(y)*100:.1f}%)")
    
    # Normalize labels to start from 0
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)
    
    # Stratified split to maintain class balance
    print(f"\n🔄 Splitting data: {1-test_size:.0%} train, {test_size:.0%} test")
    X_train_temp, X_test, y_train_temp, y_test, subj_train, subj_test = train_test_split(
        X, y_encoded, subjects, test_size=test_size, stratify=y_encoded, random_state=42
    )
    
    # Further split training data into train and validation
    print(f"🔄 Splitting training data: {1-val_size:.0%} train, {val_size:.0%} validation")
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_temp, y_train_temp, test_size=val_size, stratify=y_train_temp, random_state=42
    )
    
    # Data normalization (z-score normalization)
    print(f"\n📏 Applying z-score normalization...")
    
    # Calculate statistics from training data only
    train_mean = np.mean(X_train, axis=(0, 2), keepdims=True)  # Mean across epochs and time
    train_std = np.std(X_train, axis=(0, 2), keepdims=True)    # Std across epochs and time
    
    # Apply normalization
    X_train_norm = (X_train - train_mean) / (train_std + 1e-8)
    X_val_norm = (X_val - train_mean) / (train_std + 1e-8)
    X_test_norm = (X_test - train_mean) / (train_std + 1e-8)
    
    # Convert labels to categorical for CNN
    num_classes = len(np.unique(y_encoded))
    y_train_cat = to_categorical(y_train, num_classes)
    y_val_cat = to_categorical(y_val, num_classes)
    y_test_cat = to_categorical(y_test, num_classes)
    
    print(f"\n✅ Data preparation completed!")
    print(f"📊 Final shapes:")
    print(f"  Training: X={X_train_norm.shape}, y={y_train_cat.shape}")
    print(f"  Validation: X={X_val_norm.shape}, y={y_val_cat.shape}")
    print(f"  Test: X={X_test_norm.shape}, y={y_test_cat.shape}")
    print(f"  Number of classes: {num_classes}")
    
    # Display data statistics
    print(f"\n📈 Data Statistics (after normalization):")
    print(f"  Training data mean: {np.mean(X_train_norm):.6f}")
    print(f"  Training data std: {np.std(X_train_norm):.6f}")
    print(f"  Data range: [{np.min(X_train_norm):.3f}, {np.max(X_train_norm):.3f}]")
    
    return X_train_norm, X_val_norm, X_test_norm, y_train_cat, y_val_cat, y_test_cat, class_names

# Prepare data for CNN training
if epochs_list:
    print("🔄 Preparing data for CNN training...")
    
    # Use selected channels if available
    channels_to_use = best_channels if best_channels else None
    
    if channels_to_use:
        print(f"\n📡 Using {len(channels_to_use)} selected channels: {channels_to_use[:5]}...")
    else:
        print(f"\n📡 Using all available channels")
    
    # Prepare data
    X_train, X_val, X_test, y_train, y_val, y_test, class_names = prepare_cnn_data(
        epochs_list, selected_channels=channels_to_use, test_size=0.2, val_size=0.2
    )
    
    if X_train is not None:
        print(f"\n🎯 Ready for CNN training!")
        print(f"Motor imagery classes: {class_names}")
        
        # Store data info for later use
        data_info = {
            'n_channels': X_train.shape[1],
            'n_timepoints': X_train.shape[2],
            'n_classes': len(class_names),
            'class_names': class_names,
            'sampling_rate': epochs_list[0].info['sfreq'],
            'selected_channels': channels_to_use
        }
        
        print(f"\n📋 Dataset Summary:")
        for key, value in data_info.items():
            if key != 'selected_channels':
                print(f"  {key}: {value}")
    
    else:
        print("❌ Failed to prepare data for CNN training")
        data_info = None  # defined here so later cells can test it without a NameError
        
else:
    print("❌ No epoch data available for CNN preparation")
    X_train = X_val = X_test = y_train = y_val = y_test = class_names = None
    data_info = None
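
The normalization step above scales every split with the per-channel mean and standard deviation of the training split only, so no statistics leak from validation or test data. The following is a minimal sketch of that idea on synthetic data. The helper name `zscore_per_channel`, the `*_demo` arrays, and their sizes are assumptions made for illustration and are not part of the notebook's pipeline.

```python
import numpy as np

def zscore_per_channel(train, *others, eps=1e-8):
    """Scale every array with the per-channel mean/std of `train` only."""
    # Arrays are assumed to be shaped (epochs, channels, time points).
    mean = train.mean(axis=(0, 2), keepdims=True)   # shape (1, channels, 1)
    std = train.std(axis=(0, 2), keepdims=True)
    return tuple((x - mean) / (std + eps) for x in (train, *others))

# Synthetic stand-in for EEG epochs (hypothetical sizes).
rng = np.random.default_rng(0)
X_train_demo = rng.normal(5.0, 3.0, size=(40, 8, 160))
X_test_demo = rng.normal(5.0, 3.0, size=(10, 8, 160))

train_norm, test_norm = zscore_per_channel(X_train_demo, X_test_demo)

# Training channels end up with mean ~0 and std ~1; the test split is only
# approximately standardized because it reuses the training statistics.
print(train_norm.mean(axis=(0, 2)).round(6))
print(test_norm.mean(axis=(0, 2)).round(3))
```

Seeing a small non-zero mean on the test split is expected and is not a sign of a bug.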

In [None]:
# Cell 7: Design CNN Architectures for EEG

print("🏗️ Designing CNN Architectures for EEG Motor Imagery")
print("=" * 55)

def create_eeg_cnn_1d(input_shape, num_classes, name="EEG_CNN_1D"):
    """
    Create 1D CNN for temporal EEG signal processing.
    
    This architecture focuses on temporal patterns in EEG signals.
    """
    
    print(f"\n🔧 Building {name}...")
    print("Architecture: 1D Convolutions for temporal pattern extraction")
    
    model = models.Sequential([
        # Input layer
        layers.Input(shape=input_shape, name='input'),
        
        # Reshape for 1D convolution (channels, time) -> (time, channels)
        layers.Permute((2, 1), name='permute_for_1d'),
        
        # First temporal convolution block
        layers.Conv1D(filters=32, kernel_size=64, padding='same', name='temp_conv1'),
        layers.BatchNormalization(name='bn1'),
        layers.Activation('relu', name='relu1'),
        layers.MaxPooling1D(pool_size=4, name='pool1'),
        layers.Dropout(0.2, name='dropout1'),
        
        # Second temporal convolution block
        layers.Conv1D(filters=64, kernel_size=32, padding='same', name='temp_conv2'),
        layers.BatchNormalization(name='bn2'),
        layers.Activation('relu', name='relu2'),
        layers.MaxPooling1D(pool_size=4, name='pool2'),
        layers.Dropout(0.3, name='dropout2'),
        
        # Third temporal convolution block
        layers.Conv1D(filters=128, kernel_size=16, padding='same', name='temp_conv3'),
        layers.BatchNormalization(name='bn3'),
        layers.Activation('relu', name='relu3'),
        layers.MaxPooling1D(pool_size=2, name='pool3'),
        layers.Dropout(0.4, name='dropout3'),
        
        # Global average pooling
        layers.GlobalAveragePooling1D(name='global_avg_pool'),
        
        # Classification head
        layers.Dense(256, name='dense1'),
        layers.BatchNormalization(name='bn_dense1'),
        layers.Activation('relu', name='relu_dense1'),
        layers.Dropout(0.5, name='dropout_dense1'),
        
        layers.Dense(128, name='dense2'),
        layers.BatchNormalization(name='bn_dense2'),
        layers.Activation('relu', name='relu_dense2'),
        layers.Dropout(0.5, name='dropout_dense2'),
        
        # Output layer
        layers.Dense(num_classes, activation='softmax', name='output')
    ], name=name)
    
    return model

def create_eeg_cnn_2d(input_shape, num_classes, name="EEG_CNN_2D"):
    """
    Create 2D CNN for spatial-temporal EEG signal processing.
    
    This architecture treats EEG as 2D image (channels × time).
    """
    
    print(f"\n🔧 Building {name}...")
    print("Architecture: 2D Convolutions for spatial-temporal pattern extraction")
    
    model = models.Sequential([
        # Input layer
        layers.Input(shape=input_shape, name='input'),
        
        # Reshape to add channel dimension for 2D convolution
        layers.Reshape((*input_shape, 1), name='reshape_2d'),
        
        # First spatial-temporal convolution block
        layers.Conv2D(filters=32, kernel_size=(8, 32), padding='same', name='spattemp_conv1'),
        layers.BatchNormalization(name='bn1'),
        layers.Activation('relu', name='relu1'),
        layers.MaxPooling2D(pool_size=(2, 4), name='pool1'),
        layers.Dropout(0.2, name='dropout1'),
        
        # Second spatial-temporal convolution block
        layers.Conv2D(filters=64, kernel_size=(4, 16), padding='same', name='spattemp_conv2'),
        layers.BatchNormalization(name='bn2'),
        layers.Activation('relu', name='relu2'),
        layers.MaxPooling2D(pool_size=(2, 4), name='pool2'),
        layers.Dropout(0.3, name='dropout2'),
        
        # Third spatial-temporal convolution block
        layers.Conv2D(filters=128, kernel_size=(2, 8), padding='same', name='spattemp_conv3'),
        layers.BatchNormalization(name='bn3'),
        layers.Activation('relu', name='relu3'),
        layers.MaxPooling2D(pool_size=(1, 2), name='pool3'),
        layers.Dropout(0.4, name='dropout3'),
        
        # Global average pooling
        layers.GlobalAveragePooling2D(name='global_avg_pool'),
        
        # Classification head
        layers.Dense(256, name='dense1'),
        layers.BatchNormalization(name='bn_dense1'),
        layers.Activation('relu', name='relu_dense1'),
        layers.Dropout(0.5, name='dropout_dense1'),
        
        layers.Dense(128, name='dense2'),
        layers.BatchNormalization(name='bn_dense2'),
        layers.Activation('relu', name='relu_dense2'),
        layers.Dropout(0.5, name='dropout_dense2'),
        
        # Output layer
        layers.Dense(num_classes, activation='softmax', name='output')
    ], name=name)
    
    return model

def create_eeg_cnn_hybrid(input_shape, num_classes, name="EEG_CNN_Hybrid"):
    """
    Create hybrid CNN combining spatial and temporal processing.
    
    This architecture first extracts spatial patterns, then temporal patterns.
    """
    
    print(f"\n🔧 Building {name}...")
    print("Architecture: Hybrid spatial-first then temporal convolutions")
    
    # Input
    input_layer = layers.Input(shape=input_shape, name='input')
    
    # Reshape for 2D convolution
    x = layers.Reshape((*input_shape, 1), name='reshape_2d')(input_layer)
    
    # Spatial convolution (across channels)
    x = layers.Conv2D(filters=32, kernel_size=(input_shape[0], 1), 
                     padding='valid', name='spatial_conv')(x)
    x = layers.BatchNormalization(name='bn_spatial')(x)
    x = layers.Activation('relu', name='relu_spatial')(x)
    x = layers.Dropout(0.2, name='dropout_spatial')(x)
    
    # Reshape for temporal processing
    x = layers.Reshape((input_shape[1], 32), name='reshape_temporal')(x)
    
    # Temporal convolutions
    x = layers.Conv1D(filters=64, kernel_size=32, padding='same', name='temp_conv1')(x)
    x = layers.BatchNormalization(name='bn_temp1')(x)
    x = layers.Activation('relu', name='relu_temp1')(x)
    x = layers.MaxPooling1D(pool_size=4, name='pool_temp1')(x)
    x = layers.Dropout(0.3, name='dropout_temp1')(x)
    
    x = layers.Conv1D(filters=128, kernel_size=16, padding='same', name='temp_conv2')(x)
    x = layers.BatchNormalization(name='bn_temp2')(x)
    x = layers.Activation('relu', name='relu_temp2')(x)
    x = layers.MaxPooling1D(pool_size=4, name='pool_temp2')(x)
    x = layers.Dropout(0.4, name='dropout_temp2')(x)
    
    # Global pooling
    x = layers.GlobalAveragePooling1D(name='global_avg_pool')(x)
    
    # Classification head
    x = layers.Dense(256, name='dense1')(x)
    x = layers.BatchNormalization(name='bn_dense1')(x)
    x = layers.Activation('relu', name='relu_dense1')(x)
    x = layers.Dropout(0.5, name='dropout_dense1')(x)
    
    x = layers.Dense(128, name='dense2')(x)
    x = layers.BatchNormalization(name='bn_dense2')(x)
    x = layers.Activation('relu', name='relu_dense2')(x)
    x = layers.Dropout(0.5, name='dropout_dense2')(x)
    
    # Output
    output = layers.Dense(num_classes, activation='softmax', name='output')(x)
    
    model = models.Model(inputs=input_layer, outputs=output, name=name)
    
    return model

# Create CNN models
if X_train is not None and data_info is not None:
    print("🚀 Creating CNN architectures for EEG classification...")
    
    input_shape = (data_info['n_channels'], data_info['n_timepoints'])
    num_classes = data_info['n_classes']
    
    print(f"\n📊 Model Configuration:")
    print(f"Input shape: {input_shape} (channels, time_points)")
    print(f"Number of classes: {num_classes}")
    print(f"Classes: {data_info['class_names']}")
    
    # Create different CNN architectures
    models_dict = {}
    
    # 1D CNN
    cnn_1d = create_eeg_cnn_1d(input_shape, num_classes)
    cnn_1d.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    models_dict['1D_CNN'] = cnn_1d
    
    # 2D CNN
    cnn_2d = create_eeg_cnn_2d(input_shape, num_classes)
    cnn_2d.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    models_dict['2D_CNN'] = cnn_2d
    
    # Hybrid CNN
    cnn_hybrid = create_eeg_cnn_hybrid(input_shape, num_classes)
    cnn_hybrid.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    models_dict['Hybrid_CNN'] = cnn_hybrid
    
    print(f"\n✅ Created {len(models_dict)} CNN architectures!")
    
    # Display model summaries
    for name, model in models_dict.items():
        print(f"\n{'-'*50}")
        print(f"📋 {name} ARCHITECTURE")
        print(f"{'-'*50}")
        model.summary()
        
        # Count parameters
        total_params = model.count_params()
        print(f"\n📊 {name} Parameters: {total_params:,}")
    
    print(f"\n🎯 CNN Architecture Comparison:")
    print(f"{'Model':<15} {'Parameters':<12} {'Focus':<30}")
    print(f"{'-'*60}")
    print(f"{'1D_CNN':<15} {models_dict['1D_CNN'].count_params():<12,} {'Temporal patterns':<30}")
    print(f"{'2D_CNN':<15} {models_dict['2D_CNN'].count_params():<12,} {'Spatial-temporal patterns':<30}")
    print(f"{'Hybrid_CNN':<15} {models_dict['Hybrid_CNN'].count_params():<12,} {'Spatial then temporal':<30}")
    
    print(f"\n💡 Architecture Insights:")
    print(f"• 1D CNN: Convolves along time, with EEG channels as input features")
    print(f"• 2D CNN: Kernels span adjacent channel rows and time steps")
    print(f"• Hybrid CNN: One spatial filter bank across all channels, then temporal convolutions")
    print(f"• All use BatchNorm + Dropout for regularization")
    print(f"• Global average pooling keeps the dense head small, which helps limit overfitting")
    
else:
    print("❌ Cannot create CNN models - no training data available")
    models_dict = {}
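
All convolutions above use `padding='same'`, so only the pooling layers shrink the feature maps. With the default stride, each pooling layer divides an axis by its pool size, rounding down. The sketch below uses that rule to check whether an input is large enough for each architecture before building it. The helper `pooled_sizes` and the example sizes (640 time points, 16 or 3 channels) are assumptions for illustration; the real values come from `data_info`.

```python
def pooled_sizes(size, pool_sizes):
    """Sizes of one axis after each max-pooling stage (stride = pool size, 'valid' padding)."""
    sizes = []
    for pool in pool_sizes:
        size = size // pool
        sizes.append(size)
    return sizes

n_timepoints_demo = 640  # assumption: 4 s epochs at 160 Hz

# Time axis: the 1D and 2D CNNs pool by 4, 4, 2; the hybrid model pools by 4, 4.
print(pooled_sizes(n_timepoints_demo, [4, 4, 2]))  # [160, 40, 20]
print(pooled_sizes(n_timepoints_demo, [4, 4]))     # [160, 40]

# Channel axis of the 2D CNN: pooled by 2, 2, 1.
for n_channels_demo in (16, 3):
    sizes = pooled_sizes(n_channels_demo, [2, 2, 1])
    status = "ok" if min(sizes) > 0 else "too few channels"
    print(n_channels_demo, sizes, status)
```

Under these assumptions the 2D CNN needs at least 4 input channels, while the 1D and hybrid models have no such limit on the channel axis.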

In [None]:
# Cell 8: Train and Evaluate CNN Models

print("🚀 Training CNN Models for EEG Motor Imagery Classification")
print("=" * 60)

def train_eeg_model(model, X_train, y_train, X_val, y_val, model_name, epochs=50):
    """
    Train an EEG CNN model with comprehensive monitoring.
    """
    
    print(f"\n🔥 Training {model_name}...")
    
    # Create callbacks
    callbacks_list = [
        keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            patience=10,
            restore_best_weights=True,
            verbose=1
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=5,
            min_lr=1e-7,
            verbose=1
        ),
        keras.callbacks.ModelCheckpoint(
            f'best_{model_name.lower()}_model.weights.h5',  # Keras 3 requires the ".weights.h5" suffix for weight-only checkpoints
            monitor='val_accuracy',
            save_best_only=True,
            save_weights_only=True,
            verbose=0
        )
    ]
    
    # Train model
    history = model.fit(
        X_train, y_train,
        batch_size=32,
        epochs=epochs,
        validation_data=(X_val, y_val),
        callbacks=callbacks_list,
        verbose=1
    )
    
    return history

def evaluate_model(model, X_test, y_test, class_names, model_name):
    """
    Comprehensive model evaluation.
    """
    
    print(f"\n📊 Evaluating {model_name}...")
    
    # Get predictions
    y_pred_proba = model.predict(X_test, verbose=0)
    y_pred = np.argmax(y_pred_proba, axis=1)
    y_true = np.argmax(y_test, axis=1)
    
    # Calculate metrics
    test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
    
    print(f"\n📈 {model_name} Results:")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
    
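    # Classification report
    # Pass the label ids explicitly so a class that is absent from the test
    # split does not cause a mismatch with target_names.
    label_ids = np.arange(len(class_names))
    print(f"\n📋 Detailed Classification Report:")
    print(classification_report(y_true, y_pred, labels=label_ids,
                                target_names=class_names, digits=4, zero_division=0))
    
    # Confusion matrix (same fixed label order)
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)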
    
    # Plot confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'{model_name} - Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()
    
    return {
        'test_loss': test_loss,
        'test_accuracy': test_accuracy,
        'y_pred': y_pred,
        'y_true': y_true,
        'y_pred_proba': y_pred_proba,
        'confusion_matrix': cm
    }

def plot_training_history(histories, model_names):
    """
    Plot training histories for all models.
    """
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('EEG CNN Training Comparison', fontsize=16)
    
    colors = ['blue', 'red', 'green', 'orange']
    
    # Training accuracy
    axes[0, 0].set_title('Training Accuracy')
    for i, (name, history) in enumerate(zip(model_names, histories)):
        if history:
            axes[0, 0].plot(history.history['accuracy'], color=colors[i % len(colors)], 
                           label=f'{name}', linewidth=2)
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Accuracy')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Validation accuracy
    axes[0, 1].set_title('Validation Accuracy')
    for i, (name, history) in enumerate(zip(model_names, histories)):
        if history:
            axes[0, 1].plot(history.history['val_accuracy'], color=colors[i % len(colors)], 
                           label=f'{name}', linewidth=2)
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Validation Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # Training loss
    axes[1, 0].set_title('Training Loss')
    for i, (name, history) in enumerate(zip(model_names, histories)):
        if history:
            axes[1, 0].plot(history.history['loss'], color=colors[i % len(colors)], 
                           label=f'{name}', linewidth=2)
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Loss')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # Validation loss
    axes[1, 1].set_title('Validation Loss')
    for i, (name, history) in enumerate(zip(model_names, histories)):
        if history:
            axes[1, 1].plot(history.history['val_loss'], color=colors[i % len(colors)], 
                           label=f'{name}', linewidth=2)
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Validation Loss')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Train and evaluate models
if X_train is not None and models_dict:
    print("🎯 Starting comprehensive CNN training and evaluation...")
    print(f"\n📊 Training Configuration:")
    print(f"Training samples: {X_train.shape[0]}")
    print(f"Validation samples: {X_val.shape[0]}")
    print(f"Test samples: {X_test.shape[0]}")
    print(f"Maximum epochs: 50 (with early stopping)")
    print(f"Batch size: 32")
    
    # Train all models
    training_histories = []
    evaluation_results = {}
    
    for model_name, model in models_dict.items():
        print(f"\n{'='*60}")
        print(f"🔥 TRAINING {model_name}")
        print(f"{'='*60}")
        
        # Train model
        history = train_eeg_model(model, X_train, y_train, X_val, y_val, model_name, epochs=50)
        training_histories.append(history)
        
        # Evaluate model
        results = evaluate_model(model, X_test, y_test, class_names, model_name)
        evaluation_results[model_name] = results
        
        print(f"\n✅ {model_name} training and evaluation completed!")
    
    # Plot training comparison
    print(f"\n📈 Creating training comparison plots...")
    plot_training_history(training_histories, list(models_dict.keys()))
    
    # Model comparison summary
    print(f"\n🏆 MODEL COMPARISON SUMMARY")
    print(f"{'='*60}")
    print(f"{'Model':<15} {'Test Accuracy':<15} {'Test Loss':<12}")
    print(f"{'-'*45}")
    
    best_accuracy = 0
    best_model = None
    
    for model_name, results in evaluation_results.items():
        accuracy = results['test_accuracy']
        loss = results['test_loss']
        print(f"{model_name:<15} {accuracy:<15.4f} {loss:<12.4f}")
        
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_model = model_name
    
    print(f"\n🥇 Best Model: {best_model} with {best_accuracy:.4f} ({best_accuracy*100:.2f}%) accuracy")
    
    print(f"\n🧠 EEG CNN Classification Insights:")
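    print(f"• Chance level for {len(class_names)} balanced classes is {1/len(class_names):.2%}; judge accuracy against it")
    print(f"• Different CNN architectures capture different aspects")
    print(f"• Channel selection shrinks the input; confirm any accuracy gain against an all-channel run")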
    print(f"• Temporal patterns are crucial for EEG classification")
    
    print(f"\n✅ CNN training and evaluation completed successfully!")
    
else:
    print("❌ Cannot train models - no data or models available")
    evaluation_results = {}
    best_model = None
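
`evaluate_model` turns softmax probabilities into class predictions with `argmax` and then summarizes them in a confusion matrix. The sketch below reproduces that logic with NumPy alone on a tiny hand-made example, so the link between the matrix, accuracy, and per-class recall is easy to trace. The function name `confusion_from_probs` and all `*_demo` values are made up for illustration.

```python
import numpy as np

def confusion_from_probs(y_true_onehot, y_pred_proba):
    """Confusion matrix from one-hot labels and predicted class probabilities."""
    n_classes = y_true_onehot.shape[1]
    y_true = y_true_onehot.argmax(axis=1)
    y_pred = y_pred_proba.argmax(axis=1)
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # rows: true class, columns: predicted class
    return cm

# Six hypothetical test epochs, three classes, two epochs per class.
y_true_demo = np.eye(3)[[0, 0, 1, 1, 2, 2]]
proba_demo = np.array([
    [0.7, 0.2, 0.1],  # predicted class 0
    [0.3, 0.6, 0.1],  # predicted class 1
    [0.1, 0.8, 0.1],  # predicted class 1
    [0.2, 0.5, 0.3],  # predicted class 1
    [0.1, 0.2, 0.7],  # predicted class 2
    [0.5, 0.3, 0.2],  # predicted class 0
])

cm_demo = confusion_from_probs(y_true_demo, proba_demo)
accuracy_demo = np.trace(cm_demo) / cm_demo.sum()        # correct / total
recall_demo = np.diag(cm_demo) / cm_demo.sum(axis=1)     # per true class
chance_demo = 1 / cm_demo.shape[0]                       # balanced-class chance level

print(cm_demo)
print(f"accuracy={accuracy_demo:.3f}, chance={chance_demo:.3f}")
print("recall per class:", recall_demo)
```

Comparing accuracy with the chance level is a quick sanity check, since EEG motor imagery results can sit close to chance when few subjects are used.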

In [None]:
# Cell 9: Final Summary and Insights

print("🎓 EEG Motor Imagery CNN Classification - Complete Summary")
print("=" * 65)

if X_train is not None and evaluation_results:
    print(f"\n🧠 WHAT WE ACCOMPLISHED:")
    print(f"✅ Loaded PhysioNet EEG Motor Imagery dataset")
    print(f"✅ Preprocessed EEG signals (filtering, epoching, normalization)")
    print(f"✅ Analyzed and selected optimal EEG channels")
    print(f"✅ Designed and trained 3 different CNN architectures")
    print(f"✅ Evaluated every model on a held-out test set")
    
    print(f"\n📊 DATASET SUMMARY:")
    print(f"• Subjects processed: {len(epochs_list)}")
    print(f"• Total epochs: {X_train.shape[0] + X_val.shape[0] + X_test.shape[0]}")
    print(f"• EEG channels used: {data_info['n_channels']}")
    print(f"• Time points per epoch: {data_info['n_timepoints']}")
    print(f"• Motor imagery classes: {len(data_info['class_names'])}")
    print(f"• Class names: {', '.join(data_info['class_names'])}")
    
    print(f"\n🏆 FINAL RESULTS:")
    for model_name, results in evaluation_results.items():
        accuracy = results['test_accuracy']
        print(f"• {model_name}: {accuracy:.4f} ({accuracy*100:.2f}%) accuracy")
    
    if best_model:
        best_acc = evaluation_results[best_model]['test_accuracy']
        print(f"\n🥇 Best performing model: {best_model} ({best_acc*100:.2f}% accuracy)")
    
    print(f"\n🔍 KEY INSIGHTS LEARNED:")
    print(f"\n1. 🧠 EEG Signal Processing:")
    print(f"   • EEG signals contain rich temporal patterns")
    print(f"   • Band-pass filtering (7-30 Hz) keeps mu/beta rhythms and suppresses drift and high-frequency noise")
    print(f"   • Motor cortex channels are most informative")
    
    print(f"\n2. 🎯 Channel Selection:")
    print(f"   • Not all EEG channels are equally important")
    print(f"   • Motor cortex channels (C3, C4, Cz) are crucial")
    print(f"   • Statistical selection can improve performance")
    
    print(f"\n3. 🏗️ CNN Architecture Design:")
    print(f"   • 1D CNNs excel at temporal pattern extraction")
    print(f"   • 2D CNNs capture spatial-temporal relationships")
    print(f"   • The hybrid model applies spatial filters across channels first, then temporal convolutions")
    
    print(f"\n4. 🎮 Motor Imagery Classification:")
    print(f"   • Different motor imagery tasks create distinct patterns")
    print(f"   • CNNs can automatically learn these patterns")
    print(f"   • Real-time BCI applications are possible")
    
    print(f"\n🚀 NEXT STEPS - GRAPH CNN COMPARISON:")
    print(f"📋 Coming next: Graph CNN implementation")
    print(f"• The CNNs above mix channels without using electrode positions")
    print(f"• Graph CNN models spatial relationships between channels")
    print(f"• Expected benefits: Better spatial feature extraction")
    print(f"• Channel connectivity based on brain anatomy")
    
    print(f"\n💡 REAL-WORLD APPLICATIONS:")
    print(f"🔹 Brain-Computer Interfaces (BCI)")
    print(f"🔹 Assistive technology for paralyzed patients")
    print(f"🔹 Neurofeedback training")
    print(f"🔹 Cognitive load monitoring")
    print(f"🔹 Mental state detection")
    
    print(f"\n📚 TECHNICAL KNOWLEDGE GAINED:")
    print(f"✅ EEG signal processing pipeline")
    print(f"✅ Motor imagery neuroscience concepts")
    print(f"✅ CNN architecture design for time series")
    print(f"✅ Channel selection techniques")
    print(f"✅ Model evaluation and comparison")
    
else:
    print(f"\n⚠️  Training was not completed successfully.")
    print(f"This could be due to:")
    print(f"• Network connectivity issues (dataset download)")
    print(f"• Insufficient memory")
    print(f"• Missing dependencies")
    print(f"\nPlease check the error messages above and ensure:")
    print(f"• Stable internet connection")
    print(f"• All required packages are installed")
    print(f"• Sufficient system memory")

print(f"\n🎯 PREPARATION FOR GRAPH CNN:")
print(f"The traditional CNN approach provides our baseline.")
print(f"Next, we'll implement Graph CNN to demonstrate:")
print(f"• How spatial relationships between EEG channels matter")
print(f"• Why brain anatomy should inform model architecture")
print(f"• How Graph Neural Networks compare with this CNN baseline")

print(f"\n" + "="*65)
print(f"🧠 EEG CNN CLASSIFICATION COMPLETE - READY FOR GRAPH CNN! 🧠")
print(f"="*65)
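
As a preview of "channel connectivity based on brain anatomy", the sketch below builds a distance-based adjacency matrix between electrodes and applies the symmetric normalization commonly used in graph convolution layers. It is only an illustration: the electrode positions are an idealized 3×3 grid, not measured coordinates, and the function names, the `radius` value, and the `*_demo` variables are assumptions rather than the notebook's upcoming implementation.

```python
import numpy as np

# Idealized 3x3 grid over the motor strip; NOT measured electrode coordinates.
names = ["FC3", "FCz", "FC4", "C3", "Cz", "C4", "CP3", "CPz", "CP4"]
positions = np.array([(col, row) for row in range(3) for col in range(3)], dtype=float)

def distance_adjacency(pos, radius):
    """Connect two electrodes when their distance is at most `radius`."""
    diff = pos[:, None, :] - pos[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    adj = (dist <= radius).astype(float)
    np.fill_diagonal(adj, 0.0)  # no self-loops at this stage
    return adj

def normalize_adjacency(adj):
    """Symmetric normalization D^-1/2 (A + I) D^-1/2 with added self-loops."""
    a_hat = adj + np.eye(len(adj))
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

adj_demo = distance_adjacency(positions, radius=1.0)
adj_norm_demo = normalize_adjacency(adj_demo)

cz = names.index("Cz")
print("Neighbors of Cz:", [names[j] for j in np.flatnonzero(adj_demo[cz])])
print("Normalized adjacency shape:", adj_norm_demo.shape)
```

With real data, the grid would be replaced by the electrode coordinates of the recording montage, and the radius (or a k-nearest-neighbor rule) becomes a hyperparameter.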