In [1]:
import mne
import warnings
import time

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

from utils.Preprocessing import load_and_preprocess_data

In [2]:
# Set MNE's logger to the CRITICAL level so routine loading messages stay out
# of the notebook output.
mne.set_log_level("CRITICAL")

# Warning filters for the rest of the session:
#   * UserWarning raised from sklearn's logistic-regression module,
#   * every ConvergenceWarning (e.g. an estimator stopping at max_iter),
#   * UserWarning raised from within mne.
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    module="sklearn.linear_model._logistic",
)
warnings.filterwarnings(action="ignore", category=ConvergenceWarning)
warnings.filterwarnings(action="ignore", category=UserWarning, module="mne")

In [3]:
def train_model_pipeline(
    X_train_flattened,
    y_train,
    n_components=10,
    hidden_layer_sizes=(500, 50),
    max_iter=100,
    random_state=42,
):
    """Fit a StandardScaler -> PCA -> MLPClassifier pipeline and return it.

    Parameters
    ----------
    X_train_flattened : array-like
        Training features with one row per sample (the notebook passes
        epochs reshaped to 2-D).
    y_train : array-like
        Target label for each row of ``X_train_flattened``.
    n_components : int, default=10
        Number of principal components kept by the PCA step.
    hidden_layer_sizes : tuple of int, default=(500, 50)
        Hidden layer widths passed to ``MLPClassifier``.
    max_iter : int, default=100
        Maximum number of training iterations for ``MLPClassifier``.
    random_state : int or None, default=42
        Seed forwarded to both PCA and the MLP so that re-running the
        notebook trains the same model. Pass ``None`` to get the previous
        unseeded behaviour.

    Returns
    -------
    sklearn.pipeline.Pipeline
        The fitted pipeline.
    """
    pipeline = Pipeline([
        ('scaler', StandardScaler()),
        ('pca', PCA(n_components=n_components, random_state=random_state)),
        ('mlp', MLPClassifier(
            hidden_layer_sizes=hidden_layer_sizes,
            max_iter=max_iter,
            random_state=random_state,
        )),
    ])

    # Pipeline.fit returns the pipeline itself, i.e. the fitted model.
    return pipeline.fit(X_train_flattened, y_train)


In [4]:
# Subjects whose recordings make up the training set.
train_subjects = [57, 22, 86]

# Run numbers grouped by task; the keys double as human-readable task names
# and are used to look up the runs for the task selected later on.
train_tasks = dict(
    zip(
        (
            'Task 1:"Open and close fist"',
            'Task 2:"Imagine opening and closing fist"',
            'Task 3:"Open and close fists or feet"',
            'Task 4:"Imagine opening and closing fists or feet"',
        ),
        ([3, 7, 11], [4, 8, 12], [5, 9, 13], [6, 10, 14]),
    )
)

# Subject list reserved for testing.
test_subjects = [34]

# For testing, each task uses only its first run (3, 4, 5 and 6).
test_tasks = {task: runs[:1] for task, runs in train_tasks.items()}

In [5]:
# Task under study; also the key used to pick the training runs from
# `train_tasks`.
task_name = 'Task 3:"Open and close fists or feet"'

# Subject and run that are loaded for evaluation.
# NOTE(review): `test_subjects` / `test_tasks` defined earlier hold subject 34,
# but the evaluation uses these values instead — confirm subject 1 is intended.
test_subject, test_run = [1], [5]

In [6]:
# Load and preprocess the training recordings for the selected task.
train_runs = train_tasks[task_name]
train_epochs = load_and_preprocess_data(train_subjects, train_runs)

# Epoch data as float64; labels are the last column of the events array.
# (Data is fetched before the labels are read, as in the original order.)
X_train = train_epochs.get_data().astype('float64')
y_train = train_epochs.events[:, -1]

# Collapse every epoch into a single feature row for the sklearn pipeline.
n_train_epochs = len(X_train)
X_train_flattened = X_train.reshape(n_train_epochs, -1)

model = train_model_pipeline(X_train_flattened, y_train)

In [7]:
# Load and preprocess the recording used for evaluation.
test_epochs = load_and_preprocess_data(test_subject, test_run)

# Epoch data as float64; labels are the last column of the events array.
X_test = test_epochs.get_data().astype('float64')
y_test = test_epochs.events[:, -1]

# Same flattening as for training: one feature row per epoch.
n_test_epochs = len(X_test)
X_test_flattened = X_test.reshape(n_test_epochs, -1)

In [8]:
# Replay the test epochs one at a time, as if they arrived from a live stream,
# printing the true label, the model's prediction and whether they match.
correct_predictions = 0
total_samples = X_test_flattened.shape[0]

print("Data stream incoming...")
time.sleep(1)
print("Analyzing...")
time.sleep(1)

# perf_counter is monotonic, so the elapsed-time column cannot jump if the
# system clock is adjusted while the loop is running.
start_time = time.perf_counter()

print(f"{'Epoch':<10}{'Time Stamp':<20}{'Real movement':<15}{'Predicted movement':<20}{'Result':<10}")


for i in range(total_samples):
    # predict() expects a 2-D array, hence the (1, n_features) reshape.
    single_sample = X_test_flattened[i].reshape(1, -1)
    prediction = model.predict(single_sample)
    predicted_label = prediction[0]
    real_label = y_test[i]

    # Format elapsed time as HH:MM:SS.mmm.
    elapsed_time = time.perf_counter() - start_time
    elapsed_time_str = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
    milliseconds = int((elapsed_time % 1) * 1000)
    elapsed_time_str = f"{elapsed_time_str}.{milliseconds:03d}"

    is_correct = bool(predicted_label == real_label)
    result = "Success" if is_correct else "Fail"
    if is_correct:
        correct_predictions += 1

    # Label 1 is displayed as 'fists', any other label as 'feet'.
    # NOTE(review): confirm this matches the event ids produced by
    # load_and_preprocess_data.
    real_name = 'fists' if real_label == 1 else 'feet'
    predicted_name = 'fists' if predicted_label == 1 else 'feet'

    print(f"{i+1:<10}{elapsed_time_str:<20}{real_name:<15}{predicted_name:<20}{result:<10}")

    # Pace the output to mimic one epoch arriving per second.
    time.sleep(1)

print("\nAnalysis complete!")
if total_samples:
    accuracy = correct_predictions / total_samples * 100
    print(f"\nAccuracy: {accuracy:.2f}%")
else:
    # No epochs were loaded: accuracy is undefined, so report that instead of
    # raising ZeroDivisionError.
    accuracy = float("nan")
    print("\nAccuracy: n/a (no test epochs)")

Data stream incoming...
Analyzing...
Epoch     Time Stamp          Real movement  Predicted movement  Result    
1         00:00:00.000        fists          fists               Success   
2         00:00:01.009        feet           feet                Success   
3         00:00:02.016        fists          fists               Success   
4         00:00:03.020        fists          fists               Success   
5         00:00:04.024        feet           feet                Success   
6         00:00:05.026        fists          feet                Fail      
7         00:00:06.033        fists          fists               Success   
8         00:00:07.041        feet           fists               Fail      
9         00:00:08.042        fists          fists               Success   
10        00:00:09.051        feet           feet                Success   
11        00:00:10.057        fists          fists               Success   
12        00:00:11.061        fists          feet  