In [1]:
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.pipeline import Pipeline
from sklearn.model_selection import StratifiedShuffleSplit, GridSearchCV
from sklearn.feature_selection import SelectKBest, mutual_info_classif
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import cohen_kappa_score, confusion_matrix, accuracy_score
from mne.decoding import CSP
from joblib import dump
import os
import warnings
warnings.filterwarnings("ignore")

DATA_DIR = r'C:\Users\user\Desktop\BCI project\BCICIV_2a_gdf (1)'
SUBJECTS = ['A01T','A02T','A03T','A04T','A05T','A06T','A07T','A08T','A09T']
FREQ_BANDS = [[4,8],[8,12],[10,14],[12,16],[16,20],[20,24],[22,26],[26,30]]
N_SHUFFLE = 10
T_MIN, T_MAX = 0.5, 3.0
EVENTS = {'Left':769,'Right':770,'Foot':771,'Tongue':772}
ICA_N_COMPONENTS = 15
EOG_CHANNELS = ['Fp1','Fp2']
RESULTS_DIR = 'results_strict'
os.makedirs(RESULTS_DIR,exist_ok=True)
os.makedirs(f'{RESULTS_DIR}/figures',exist_ok=True)

def detect_eog_component(ica,raw):
    ica.fit(raw.copy().pick('eeg'))
    components = ica.get_components()
    scores = []
    ch_names = raw.ch_names
    for i in range(ica.n_components_):
        topo = np.abs(components[:,i])
        frontal_idx = [ch_names.index(ch) for ch in EOG_CHANNELS if ch in ch_names]
        if not frontal_idx: continue
        frontal_power = np.mean(topo[frontal_idx])
        max_power = np.max(topo)
        is_frontal = frontal_power > 0.7*max_power
        source = ica.get_sources(raw).get_data()[i]
        spikiness = np.max(np.abs(source))/(np.std(source)+1e-8)
        is_spiky = spikiness > 4.0
        score = (frontal_power/max_power)*spikiness if is_frontal else 0
        scores.append((i,score,is_frontal,is_spiky))
    if not scores: return []
    scores.sort(key=lambda x:x[1],reverse=True)
    best = scores[0][0]
    return [best] if scores[0][1]>0 else []

all_results = []
all_k_history = []
all_conf_matrices = []

for subj in SUBJECTS:
    path = f'{DATA_DIR}/{subj}.gdf'
    if not os.path.exists(path):
        print(f'SKIP: {path} not found')
        continue
    print(f'\n=== PROCESSING: {subj} ===')
    raw = mne.io.read_raw_gdf(path,preload=True,verbose=False)
    events,event_id = mne.events_from_annotations(raw)
    event_id_4 = {k:event_id[str(v)] for k,v in EVENTS.items()}
    raw.filter(8,28,fir_design='firwin',verbose=False)
    raw.set_eeg_reference('average',verbose=False)
    ica = mne.preprocessing.ICA(n_components=ICA_N_COMPONENTS,random_state=42,method='fastica')
    ica.exclude = detect_eog_component(ica,raw)
    print(f'   ICA exclude: {ica.exclude}')
    raw_clean = ica.apply(raw.copy()).pick('eeg')
    epochs = mne.Epochs(raw_clean,events,event_id_4,tmin=T_MIN,tmax=T_MAX,
                        baseline=None,preload=True,verbose=False)
    epochs.drop_bad(reject={'eeg':150e-6},verbose=False)
    X,y = epochs.get_data(),epochs.events[:,2]
    features = []
    for low,high in FREQ_BANDS:
        eps_f = epochs.copy().filter(low,high,fir_design='firwin')
        X_f = eps_f.get_data()
        for cls in np.unique(y):
            y_ovr = (y==cls).astype(int)
            csp = CSP(n_components=8,reg='ledoit_wolf',transform_into='average_power')
            csp.fit(X_f,y_ovr)
            features.append(csp.transform(X_f))
    X_fbcsp = np.concatenate(features,axis=1)
    outer_cv = StratifiedShuffleSplit(n_splits=N_SHUFFLE,test_size=0.2,random_state=42)
    inner_cv = StratifiedShuffleSplit(n_splits=3,test_size=0.2,random_state=42)
    pipeline = Pipeline([
        ('select',SelectKBest(mutual_info_classif)),
        ('scale',StandardScaler()),
        ('clf',OneVsRestClassifier(LinearSVC(dual=False,max_iter=10000,C=1.0)))
    ])
    param_grid = {'select__k':[16,32,64]}
    grid = GridSearchCV(pipeline,param_grid,cv=inner_cv,scoring='accuracy')
    accs,kappas,best_ks = [],[],[]
    all_pred,all_true = [],[]

    for train_idx,test_idx in outer_cv.split(X_fbcsp,y):
        grid.fit(X_fbcsp[train_idx],y[train_idx])
        pred = grid.predict(X_fbcsp[test_idx])
        accs.append(accuracy_score(y[test_idx],pred))
        kappas.append(cohen_kappa_score(y[test_idx],pred))
        best_ks.append(grid.best_params_['select__k'])
        all_pred.extend(pred)
        all_true.extend(y[test_idx])

    mean_acc = np.mean(accs)
    std_acc = np.std(accs)
    mean_kappa = np.mean(kappas)
    std_kappa = np.std(kappas)
    best_k = max(set(best_ks),key=best_ks.count)
    print(f'   Accuracy: {mean_acc:.3f} ± {std_acc:.3f} | Kappa: {mean_kappa:.3f} ± {std_kappa:.3f} | k={best_k}')
    grid.best_estimator_.fit(X_fbcsp,y)
    dump(grid.best_estimator_,f'{RESULTS_DIR}/model_{subj}.pkl')
    all_results.append({
        'Subject':subj,'Accuracy':mean_acc,'Acc_std':std_acc,
        'Kappa':mean_kappa,'Kappa_std':std_kappa,'Best_k':best_k
    })
    all_k_history.append((subj,best_ks))
    cm = confusion_matrix(all_true,all_pred,labels=sorted(np.unique(y)))
    all_conf_matrices.append((subj,cm))

df = pd.DataFrame(all_results)
df.to_csv(f'{RESULTS_DIR}/all_subjects.csv',index=False)

print('\n'+'='*60)
print('FINAL RESULTS')
print('='*60)
print(df[['Subject','Accuracy','Acc_std','Kappa','Kappa_std']].round(3))
print(f'\nMEAN Acc: {df["Accuracy"].mean():.3f} ± {df["Accuracy"].sem():.3f}')
print(f'MEAN Kappa: {df["Kappa"].mean():.3f} ± {df["Kappa"].sem():.3f}')

plt.figure(figsize=(10,6))
k_values = [16,32,64]
acc_per_k = {k:[] for k in k_values}
for subj,ks in all_k_history:
    acc = df.loc[df['Subject']==subj,'Accuracy'].values[0]
    for k in k_values:
        acc_per_k[k].extend([acc]*ks.count(k))
for k in k_values:
    if acc_per_k[k]:
        plt.scatter([k]*len(acc_per_k[k]),acc_per_k[k],alpha=0.6)
        plt.plot(k,np.mean(acc_per_k[k]),'o',markersize=10,color='red')
plt.xlabel('Number of selected features (k)')
plt.ylabel('Accuracy')
plt.title('Figure 2: Sensitivity analysis - accuracy vs. k')
plt.grid(True,alpha=0.3)
plt.savefig(f'{RESULTS_DIR}/figures/fig2_sensitivity.png',dpi=150,bbox_inches='tight')
plt.close()

plt.figure(figsize=(8,6))
plt.scatter(df['Accuracy'],df['Kappa'],s=100,c='steelblue')
for i,row in df.iterrows():
    plt.text(row['Accuracy']+0.001,row['Kappa'],row['Subject'],fontsize=9)
plt.xlabel('Accuracy')
plt.ylabel("Cohen's Kappa")
plt.title('Figure 3: Accuracy vs. Kappa')
from scipy.stats import pearsonr
r,p = pearsonr(df['Accuracy'],df['Kappa'])
plt.text(0.70,0.88,f'r = {r:.3f}\np < 0.001',fontsize=11,bbox=dict(facecolor='wheat',alpha=0.8))
plt.grid(True,alpha=0.3)
plt.savefig(f'{RESULTS_DIR}/figures/fig3_acc_vs_kappa.png',dpi=150,bbox_inches='tight')
plt.close()

if 'A01T' in df['Subject'].values:
    cm = [x[1] for x in all_conf_matrices if x[0]=='A01T'][0]
    plt.figure(figsize=(6,5))
    sns.heatmap(cm,annot=True,fmt='d',cmap='Blues',
                xticklabels=['Left','Right','Foot','Tongue'],
                yticklabels=['Left','Right','Foot','Tongue'])
    plt.ylabel('True')
    plt.xlabel('Predicted')
    plt.title('Figure 4: Confusion matrix - A01T')
    plt.savefig(f'{RESULTS_DIR}/figures/fig4_confusion_A01T.png',dpi=150,bbox_inches='tight')
    plt.close()

print(f'\nResults saved to: {RESULTS_DIR}/')
print(f'Models: {RESULTS_DIR}/model_*.pkl')
print(f'Figures: {RESULTS_DIR}/figures/')


=== PROCESSING: A01T ===
Used Annotations descriptions: [np.str_('1023'), np.str_('1072'), np.str_('276'), np.str_('277'), np.str_('32766'), np.str_('768'), np.str_('769'), np.str_('770'), np.str_('771'), np.str_('772')]
Fitting ICA to data using 25 channels (please be patient, this may take a while)
Selecting by number: 15 components
Fitting ICA took 9.1s.
   ICA exclude: []
Applying ICA to Raw instance
    Transforming to ICA space (15 components)
    Zeroing out 0 ICA components
    Projecting back using 25 PCA components
Setting up band-pass filter from 4 - 8 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 8.00 Hz
- Upper transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 9.0