In [6]:
import mne
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import ShuffleSplit, cross_val_score, cross_val_predict, StratifiedKFold
from sklearn.multiclass import OneVsRestClassifier
from mne.decoding import CSP
from mne.preprocessing import ICA
from sklearn.feature_selection import SelectKBest, mutual_info_classif
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from joblib import dump
import warnings
import os

# Silence every RuntimeWarning so the printed log stays readable.
warnings.filterwarnings(action="ignore", category=RuntimeWarning)
# Non-interactive Matplotlib: figures are only written to disk, never shown,
# so plotting calls do not block the run.
plt.ioff()

# --- 1. CONFIGURATION PARAMETERS ---
# Optimized parameters for Subject A01T (BCIC IV-2a dataset)
# The GDF path may be overridden with the BCI_DATA_PATH_GDF environment
# variable; the default is the original author's local path.
DATA_PATH_GDF = os.environ.get(
    'BCI_DATA_PATH_GDF',
    'C:/Users/user/Desktop/BCI project/BCICIV_2a_gdf (1)/A01T.gdf',
)

# Hyperparameters
N_CSP_COMPONENTS = 8   # Number of components for CSP (4 extreme pairs)
K_BEST_FEATURES = 32   # Number of top features selected (Mutual Information)
N_SHUFFLE_SPLITS = 15  # CV splits for Mean Accuracy calculation (Reported Result)
N_KFOLD_SPLITS = 10    # CV splits for Confusion Matrix (CM)
T_MIN = 0.5            # Start of MI time window (s)
T_MAX = 3.0            # End of MI time window (s)
T_MIN_DEMO = -0.5      # Start of demo window (for baseline visualization)
T_MAX_DEMO = 4.0       # End of demo window

# Optimized Filter Bank frequency bands (8 sub-bands), in Hz.
# NOTE(review): the preprocessing step band-passes the data to 8-28 Hz before
# the filter bank is applied, so the 4-8 Hz and 26-30 Hz bands mostly contain
# attenuated signal.
FREQ_BANDS = [
    [4., 8.], [8., 12.], [10., 14.], [12., 16.],
    [16., 20.], [20., 24.], [22., 26.], [26., 30.],
]

# GDF event codes of the four motor-imagery cues.
TARGET_EVENT_CODES = {'Left_hand': 769, 'Right_hand': 770, 'Foot': 771, 'Tongue': 772}
# Class names, derived from the mapping above so order and spelling cannot drift.
TARGET_CLASS_NAMES = list(TARGET_EVENT_CODES)


def load_and_preprocess_data(data_path):
    """Read a GDF recording and return ICA-cleaned EEG plus event information.

    Processing steps, in order:
    read the file, derive events from its annotations, rename EEG channels to
    ``EEG_<index>`` when channel names are duplicated, mark the EOG channels
    as type ``'eog'``, band-pass 8-28 Hz, add an average-reference projector
    (it is applied later in this script, when epochs are built with
    ``proj=True``), then zero out the ICA components flagged as EOG-related.

    Parameters
    ----------
    data_path : str
        Path of the ``.gdf`` file to load.

    Returns
    -------
    tuple
        ``(cleaned EEG-only Raw, events array, {class name: event id} for the
        four target classes, original Raw)``. The original Raw is returned for
        plotting; note its EOG channel types have been set in place.

    Raises
    ------
    KeyError
        If one of ``TARGET_EVENT_CODES`` is absent from the file's annotations.
    """
    print("--- 2. DATA LOADING & PREPROCESSING ---")
    raw = mne.io.read_raw_gdf(data_path, preload=True, verbose=False)
    events, event_id = mne.events_from_annotations(raw)

    # Duplicate channel names break MNE channel handling; give EEG channels
    # unique index-based names in that case.
    names = raw.ch_names
    if len(set(names)) != len(names):
        renamed = {}
        for idx, name in enumerate(names):
            renamed[name] = f'EEG_{idx}' if name.startswith('EEG') else name
        raw.rename_channels(renamed)

    present_eog = {
        name: 'eog'
        for name in ('EOG-left', 'EOG-central', 'EOG-right')
        if name in raw.ch_names
    }
    if present_eog:
        raw.set_channel_types(present_eog, verbose=False)

    # Translate the GDF cue codes into the integer ids MNE assigned to them.
    class_event_ids = {}
    for label, code in TARGET_EVENT_CODES.items():
        class_event_ids[label] = event_id[str(code)]

    # Band-pass 8-28 Hz and add (not yet apply) a common-average reference.
    raw_work = raw.copy().pick(['eeg', 'eog'])
    raw_work.filter(l_freq=8.0, h_freq=28.0, fir_design='firwin', verbose=False)
    raw_work.set_eeg_reference(ref_channels='average', projection=True, verbose=False)

    # ICA: drop components correlated with the EOG channels, keep EEG only.
    ica = ICA(n_components=15, method='fastica', random_state=42)
    ica.fit(raw_work.copy())
    ica.exclude = ica.find_bads_eog(raw_work, threshold=3.0, verbose=False)[0]
    cleaned = ica.apply(raw_work.copy())
    cleaned.pick(['eeg'])

    return cleaned, events, class_event_ids, raw

def extract_fbcsp_features(raw_eeg, events, target_event_id_map):
    """Extract FBCSP-OVR features, keep the most informative ones, and standardise them.

    For every band in ``FREQ_BANDS`` the recording is band-pass filtered and
    epoched over ``[T_MIN, T_MAX]``. One CSP (``N_CSP_COMPONENTS`` components,
    average-power output, Ledoit-Wolf regularisation) is fitted per class in a
    one-vs-rest fashion. The concatenated features are reduced to
    ``K_BEST_FEATURES`` by mutual-information ranking and then z-scored.

    Parameters
    ----------
    raw_eeg : mne.io.Raw
        Preprocessed EEG recording.
    events : numpy.ndarray
        Events array (as returned by ``mne.events_from_annotations``).
    target_event_id_map : dict
        ``{class name: event id}`` for the classes to decode.

    Returns
    -------
    X_scaled : numpy.ndarray, shape (n_epochs, K_BEST_FEATURES)
        Selected, standardised features.
    y_base : numpy.ndarray, shape (n_epochs,)
        Event id of every retained epoch, row-aligned with ``X_scaled``.

    Raises
    ------
    RuntimeError
        If a band's epochs cannot be aligned with the reference epochs.

    NOTE(review): CSP, SelectKBest and StandardScaler are all fitted on every
    epoch here, before any cross-validation split. Accuracies computed
    downstream on ``X_scaled`` are therefore optimistically biased. An unbiased
    estimate requires fitting these steps inside each training fold.
    """
    from functools import partial

    print("\n--- 3. FBCSP FEATURE EXTRACTION & SELECTION ---")

    target_codes = np.array(list(target_event_id_map.values()))
    epoch_kwargs = dict(tmin=T_MIN, tmax=T_MAX, proj=True, baseline=None,
                        preload=True, verbose=False)

    # 1. Reference epochs (time window T_MIN..T_MAX). Amplitude rejection is
    # decided ONCE here and every band below reuses exactly this set of trials.
    # Rejecting per band could drop different trials in different bands and
    # misalign the band features with y_base.
    epochs_base = mne.Epochs(raw_eeg, events, target_event_id_map, **epoch_kwargs)
    epochs_base.drop_bad(reject={'eeg': 150e-6})
    y_base = epochs_base.events[:, 2]
    kept_events = epochs_base.selection  # indices into `events` of surviving trials

    features = []

    # 2. FBCSP feature generation:
    # len(FREQ_BANDS) bands x one OVR problem per class x N_CSP_COMPONENTS.
    for low, high in FREQ_BANDS:
        raw_band = raw_eeg.copy().filter(l_freq=low, h_freq=high, fir_design='firwin', verbose=False)
        epochs_band = mne.Epochs(raw_band, events, target_event_id_map, **epoch_kwargs)
        # Keep exactly the trials that survived rejection on the reference epochs.
        keep = np.flatnonzero(np.isin(epochs_band.selection, kept_events))
        epochs_band = epochs_band[keep]
        if not np.array_equal(epochs_band.events[:, 2], y_base):
            raise RuntimeError(
                f"Band {low}-{high} Hz: {len(epochs_band)} epochs could not be "
                f"aligned with the {len(y_base)} reference epochs."
            )
        X_band = epochs_band.get_data()

        # OVR CSP for each class
        for target_code in target_codes:
            y_ovr = np.where(y_base == target_code, 1, 0)
            csp = CSP(n_components=N_CSP_COMPONENTS, reg='ledoit_wolf', transform_into='average_power', norm_trace=False)
            csp.fit(X_band, y_ovr)
            features.append(csp.transform(X_band))

    X = np.concatenate(features, axis=1)
    print(f"Total Features Generated: {X.shape[1]}")

    # 3. Feature selection (k = K_BEST_FEATURES). mutual_info_classif is
    # stochastic, so seed it like the rest of the pipeline (random_state=42)
    # to make the selected features reproducible between runs.
    mi_score = partial(mutual_info_classif, random_state=42)
    selector = SelectKBest(score_func=mi_score, k=K_BEST_FEATURES)
    X_selected = selector.fit_transform(X, y_base)
    print(f"Features Selected (k={K_BEST_FEATURES}): {X_selected.shape[1]}")

    # 4. Feature scaling (standardisation)
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_selected)

    return X_scaled, y_base

def classify_and_evaluate(X_scaled, y_base):
    """Cross-validate a one-vs-rest LinearSVC with ShuffleSplit and plot per-split accuracy.

    Uses ``N_SHUFFLE_SPLITS`` random splits with a 20 % test fraction (seed 42).
    The per-split accuracies, their mean, and the chance level
    ``1 / len(TARGET_CLASS_NAMES)`` are drawn and saved to
    ``accuracy_plot_A01T_final.png``.

    Parameters
    ----------
    X_scaled : array-like, shape (n_epochs, n_features)
        Feature matrix.
    y_base : array-like, shape (n_epochs,)
        Class label (event id) per epoch.

    Returns
    -------
    classifier : OneVsRestClassifier
        The cross-validated estimator, returned unfitted.
    scores : numpy.ndarray, shape (N_SHUFFLE_SPLITS,)
        Test accuracy of each split.
    """
    print("\n--- 4. CLASSIFICATION (OVR-LinearSVC) ---")

    svc = LinearSVC(random_state=42, multi_class='ovr', dual=False, max_iter=10000)
    classifier = OneVsRestClassifier(svc)
    splitter = ShuffleSplit(n_splits=N_SHUFFLE_SPLITS, test_size=0.2, random_state=42)
    scores = cross_val_score(classifier, X_scaled, y_base, cv=splitter, n_jobs=-1)

    mean_accuracy, std_accuracy = scores.mean(), scores.std()
    chance_level = 1.0 / len(TARGET_CLASS_NAMES)

    print("Classifier: FBCSP-OVR-LinearSVC")
    print(f"Mean Accuracy: {mean_accuracy:.3f} +/- {std_accuracy:.3f}")

    # Accuracy figure, written to disk only.
    print("Generating Final Accuracy Plot...")
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(range(1, N_SHUFFLE_SPLITS + 1), scores, marker='o', color='#36A2EB', label='Accuracy per Split')
    ax.axhline(y=mean_accuracy, color='r', linestyle='--', label=f'Mean: {mean_accuracy:.3f}')
    ax.axhline(y=chance_level, color='g', linestyle=':', label=f'Chance Level: {chance_level:.2f}')
    ax.set_xlabel('Cross-Validation Split Number')
    ax.set_ylabel('Accuracy')
    ax.set_title('4-Class FBCSP-OVR-LinearSVC Accuracy (Subject A01T)')
    ax.legend()
    ax.set_ylim(0, 1.0)
    ax.grid(True)
    fig.savefig('accuracy_plot_A01T_final.png', dpi=300)
    plt.close(fig)
    print("Saved 'accuracy_plot_A01T_final.png'")

    return classifier, scores

def generate_confusion_matrix(classifier, X_scaled, y_base, class_codes=None):
    """Plot and print a cross-validated confusion matrix (StratifiedKFold).

    Out-of-fold predictions are obtained with ``cross_val_predict`` over
    ``N_KFOLD_SPLITS`` shuffled, stratified folds (seed 42). The figure is
    saved to ``confusion_matrix_A01T_final.png`` and the row-normalised
    matrix is printed.

    Parameters
    ----------
    classifier : sklearn estimator
        Estimator to cross-validate.
    X_scaled : array-like, shape (n_epochs, n_features)
        Feature matrix.
    y_base : numpy.ndarray, shape (n_epochs,)
        Event id per epoch.
    class_codes : sequence of int, optional
        Event ids listed in the same order as ``TARGET_CLASS_NAMES``. It fixes
        the row/column order of the matrix. Defaults to the sorted unique values
        of ``y_base``. That default is only correct when the sorted ids follow
        the ``TARGET_CLASS_NAMES`` order.

    Returns
    -------
    float
        Accuracy of the pooled out-of-fold predictions.

    Raises
    ------
    ValueError
        If the number of class ids differs from ``len(TARGET_CLASS_NAMES)``.
        Raising here avoids an obscure failure later in plotting and printing.
    """
    print("\n--- 5. CONFUSION MATRIX ANALYSIS & ARTIFACT SAVING ---")

    labels = np.unique(y_base) if class_codes is None else np.asarray(class_codes)
    if len(labels) != len(TARGET_CLASS_NAMES):
        raise ValueError(
            f"Expected {len(TARGET_CLASS_NAMES)} class ids matching TARGET_CLASS_NAMES, "
            f"got {len(labels)}: {list(labels)}"
        )

    cv_kfold = StratifiedKFold(n_splits=N_KFOLD_SPLITS, shuffle=True, random_state=42)
    y_pred = cross_val_predict(classifier, X_scaled, y_base, cv=cv_kfold, n_jobs=-1)
    kfold_accuracy = np.mean(y_base == y_pred)

    cm = confusion_matrix(y_base, y_pred, labels=labels)

    fig, ax = plt.subplots(figsize=(8, 8))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=TARGET_CLASS_NAMES)
    disp.plot(cmap=plt.cm.Blues, ax=ax, values_format='d')
    ax.set_title(f'Cross-Validated Confusion Matrix (Accuracy {kfold_accuracy:.3f})')

    # Save the plot for the publication
    fig.savefig('confusion_matrix_A01T_final.png', dpi=300)
    plt.close(fig)
    print("Saved 'confusion_matrix_A01T_final.png'")

    # Row-normalise. A class with no true samples gets a zero row instead of NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)

    print("\nNormalized Confusion Matrix (Row: True Class, Column: Predicted Class):")
    print("---------------------------------------------------------------------")
    for true_name, row in zip(TARGET_CLASS_NAMES, cm_norm):
        predictions = ", ".join(
            f"{pred_name}: {value:.3f}" for pred_name, value in zip(TARGET_CLASS_NAMES, row)
        )
        print(f"{true_name}: {predictions}")

    return kfold_accuracy

def generate_method_visualizations(raw, events, event_id_4_classes):
    """Save the non-interactive control figures used in the Methods section.

    Writes two PNG files to the working directory:

    * ``raw_data_sample_control.png``: a 10 s window of 8-28 Hz band-passed
      EEG, with 8 channels shown.
    * ``average_epoch_left_hand.png``: a channel-mean epoch image for the
      ``'Left_hand'`` class over ``[T_MIN_DEMO, T_MAX_DEMO]``, baseline-corrected
      on ``[T_MIN_DEMO, 0]``.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to visualise; it is copied, never modified.
    events : numpy.ndarray
        Events array used for epoching.
    event_id_4_classes : dict
        ``{class name: event id}``; must contain ``'Left_hand'``.
    """
    print("\n--- 6. GENERATING METHOD VISUALIZATIONS ---")

    # 1. Raw EEG plot (Preprocessing section). pick('eeg', exclude='bads') is
    # the non-legacy equivalent of pick_types(eeg=True).
    raw_eeg_filtered_display = raw.copy().pick('eeg', exclude='bads')
    raw_eeg_filtered_display.filter(l_freq=8.0, h_freq=28.0, fir_design='firwin', verbose=False)

    fig_raw = raw_eeg_filtered_display.plot(
        duration=10.0, n_channels=8, show=False, scalings=dict(eeg=100e-6),
        title="Raw Filtered Data Sample (Control)"
    )
    fig_raw.savefig('raw_data_sample_control.png', dpi=300)
    plt.close(fig_raw)
    print("Saved 'raw_data_sample_control.png'")

    # 2. Average epoch plot (ERD/ERS demonstration). The baseline starts at the
    # demo window start, so both stay consistent if T_MIN_DEMO changes.
    epochs = mne.Epochs(
        raw_eeg_filtered_display.copy(), events, event_id_4_classes,
        tmin=T_MIN_DEMO, tmax=T_MAX_DEMO, proj=True, baseline=(T_MIN_DEMO, 0),
        preload=True, verbose=False
    )

    figs = epochs['Left_hand'].plot_image(
        combine='mean', title='Average Epoch: Left Hand (ERD/ERS Demo)', show=False
    )
    # plot_image returns its figures. Save the last one explicitly instead of
    # relying on pyplot's implicit "current figure", then close all of them.
    figs[-1].savefig('average_epoch_left_hand.png', dpi=300)
    for fig in figs:
        plt.close(fig)
    print("Saved 'average_epoch_left_hand.png'")


if __name__ == "__main__":
    # Local import: only the entry point needs it, to report failures in full.
    import traceback

    if not os.path.exists(DATA_PATH_GDF):
        print(f"Error: The data file path was not found: {DATA_PATH_GDF}")
    else:
        try:
            # Step 1: Load and Preprocess Data (includes ICA and CAR)
            raw_eeg_cleaned, events, event_id_4_classes, raw_original = load_and_preprocess_data(DATA_PATH_GDF)

            # Step 2: Generate all Method Visualizations (Non-blocking)
            generate_method_visualizations(raw_original, events, event_id_4_classes)

            # Step 3: Extract FBCSP Features (includes SelectKBest and Scaling)
            X_scaled, y_base = extract_fbcsp_features(raw_eeg_cleaned, events, event_id_4_classes)

            # Step 4: Classify and Evaluate (Uses ShuffleSplit)
            classifier, scores = classify_and_evaluate(X_scaled, y_base)

            # Step 5: Save Final Model
            # NOTE(review): only the classifier is persisted. The CSP filters,
            # feature selector and scaler fitted in extract_fbcsp_features are
            # not saved, so this file alone cannot score new raw recordings.
            classifier.fit(X_scaled, y_base)
            dump(classifier, 'ovr_svc_model_A01T_final.pkl')
            print("\nModel saved as 'ovr_svc_model_A01T_final.pkl'.")

            # Step 6: Generate Confusion Matrix and save artifact
            kfold_accuracy = generate_confusion_matrix(classifier, X_scaled, y_base)

            print(f"\nFinal Summary: ShuffleSplit Accuracy (Reported): {np.mean(scores):.3f} +/- {np.std(scores):.3f} | K-Fold Accuracy (Matrix): {kfold_accuracy:.3f}")

        except Exception as e:
            # Keep the one-line summary, but also print the traceback. The
            # message alone did not show where in the pipeline the failure happened.
            print(f"Critical error during execution: {str(e)}")
            traceback.print_exc()

--- 2. DATA LOADING & PREPROCESSING ---
Used Annotations descriptions: [np.str_('1023'), np.str_('1072'), np.str_('276'), np.str_('277'), np.str_('32766'), np.str_('768'), np.str_('769'), np.str_('770'), np.str_('771'), np.str_('772')]
Fitting ICA to data using 22 channels (please be patient, this may take a while)
Selecting by number: 15 components
Fitting ICA took 10.6s.
Applying ICA to Raw instance
    Transforming to ICA space (15 components)
    Zeroing out 1 ICA component
    Projecting back using 22 PCA components

--- 6. GENERATING METHOD VISUALIZATIONS ---
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Using matplotlib as 2D backend.
Saved 'raw_data_sample_control.png'
Not setting metadata
72 matching events found
No baseline correction applied
0 projection items activated
combining channels using "mean"
Saved 'average_epoch_left_hand.png'

--- 3. FBCSP FEATURE EXTRACTION & SELECTION ---
0 bad epochs dropped
0 bad epochs dropped
Computing rank fro