<a href="https://colab.research.google.com/github/thatGuyPdeep/BCI_MTP_IITD/blob/main/Updated_LSTM_Model_BCI_Competition_IV_2a.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!cp /content/drive/MyDrive/BCICIV_2a_gdf.zip /content

In [None]:
%%capture
!unzip /content/BCICIV_2a_gdf.zip

In [None]:
# Install required libraries
!pip install mne
!pip install PyWavelets
!pip install seaborn matplotlib
!pip install keras tensorflow  # Optional if you don't have these installed already



In [None]:
import os
import numpy as np
import mne
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout, Bidirectional
from keras.callbacks import EarlyStopping
from scipy.stats import zscore
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import pywt

In [70]:
import os
import glob

# Define the path to your plots folder
plots_folder = '/content/plots/*'  # Adjust this path as necessary

# Delete all files in the plots folder
for file in glob.glob(plots_folder):
    os.remove(file)

print("All plots deleted successfully.")


All plots deleted successfully.


In [71]:
# Function to read and preprocess data
def read_path(path):
    raw = mne.io.read_raw_gdf(path, eog=['EOG-left', 'EOG-central', 'EOG-right'], preload=True)
    raw.drop_channels(['EOG-left', 'EOG-central', 'EOG-right'])
    raw.set_eeg_reference()
    events = mne.events_from_annotations(raw)
    epoch = mne.Epochs(raw, events[0], event_id=[7, 8, 9, 10], on_missing='warn')
    labels = epoch.events[:, -1]
    features = epoch.get_data()

    # Normalize features
    features = (features - np.mean(features, axis=0)) / np.std(features, axis=0)

    return labels, features

In [72]:
# List of dataset files (Ensure these files are uploaded to Colab or accessible)
dataset_files = [
    "A01E.gdf", "A01T.gdf", "A02E.gdf", "A02T.gdf",
    "A03E.gdf", "A03T.gdf", "A04E.gdf", "A04T.gdf",
    "A05E.gdf", "A05T.gdf", "A06E.gdf", "A06T.gdf",
    "A07E.gdf", "A07T.gdf", "A08E.gdf", "A08T.gdf",
    "A09E.gdf", "A09T.gdf"
]

In [73]:
# Create output directory
output_dir = "EEG_Analysis_Results"
os.makedirs(output_dir, exist_ok=True)

In [74]:
# Function to preprocess and read EEG data
def read_and_preprocess(path):
    raw = mne.io.read_raw_gdf(path, preload=True)
    raw.drop_channels(['EOG-left', 'EOG-central', 'EOG-right'])
    raw.set_eeg_reference()
    events = mne.events_from_annotations(raw)[0]
    epoch = mne.Epochs(raw, events, event_id=[7, 8, 9, 10], on_missing='warn')
    labels = epoch.events[:, -1]
    features = epoch.get_data()

    # Normalize features
    features = zscore(features, axis=0)
    return labels, features, raw

In [84]:
def plot_visualizations(dataset_name, raw, features, labels, events):
    # Average the features across epochs
    if features.ndim == 3:  # Check if features have three dimensions
        features_avg = np.mean(features, axis=0)  # Shape will be (n_channels, n_samples)
        print("Averaged features shape:", features_avg.shape)
    else:
        features_avg = features

    # 1. PSD Plot
    plt.figure()
    raw.plot_psd(fmin=1, fmax=40, show=False)
    plt.savefig(f'{dataset_name}_psd.png')
    plt.close()

    # 2. Evoked Plot
    evoked = mne.EvokedArray(features_avg, raw.info, tmin=0)
    plt.figure()
    evoked.plot(show=False)
    plt.savefig(f'{dataset_name}_evoked.png')
    plt.close()

    print("Visualizations saved for dataset:", dataset_name)


In [83]:
# Main loop to process each dataset
for dataset in dataset_files:
    # Read and preprocess the dataset
    labels, features, raw = read_and_preprocess(dataset)

    events, _ = mne.events_from_annotations(raw)

    # Plot visualizations
    plot_visualizations(dataset.split('.')[0], raw, features, labels, events)

    # Split data into training and test sets
    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)

    # One-hot encode labels
    encoder = OneHotEncoder(sparse_output=False)
    y_train_encoded = encoder.fit_transform(y_train.reshape(-1, 1))
    y_test_encoded = encoder.transform(y_test.reshape(-1, 1))

    # Reshape for LSTM (samples, timesteps, features)
    X_train = X_train.reshape((X_train.shape[0], X_train.shape[1], X_train.shape[2]))
    X_test = X_test.reshape((X_test.shape[0], X_test.shape[1], X_test.shape[2]))

    # Build the LSTM model
    model = Sequential()
    model.add(Bidirectional(LSTM(64, return_sequences=True), input_shape=(X_train.shape[1], X_train.shape[2])))
    model.add(Dropout(0.5))
    model.add(Bidirectional(LSTM(64)))
    model.add(Dropout(0.5))
    model.add(Dense(y_train_encoded.shape[1], activation='softmax'))

    # Compile the model
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    # Fit the model
    history = model.fit(X_train, y_train_encoded, validation_split=0.2, epochs=100, batch_size=32,
                        callbacks=[EarlyStopping(monitor='val_loss', patience=10)])

    # Plot Accuracy and Loss Curves
    plt.figure(figsize=(12, 4))
    plt.plot(history.history['accuracy'], label='Train Accuracy')
    plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
    plt.title(f'Accuracy Curves - {dataset.split(".")[0]}')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.savefig(os.path.join(output_dir, f'{dataset.split(".")[0]}_accuracy.png'))
    plt.close()

    plt.figure(figsize=(12, 4))
    plt.plot(history.history['loss'], label='Train Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title(f'Loss Curves - {dataset.split(".")[0]}')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(os.path.join(output_dir, f'{dataset.split(".")[0]}_loss.png'))
    plt.close()

    # Evaluate the model
    test_loss, test_accuracy = model.evaluate(X_test, y_test_encoded)
    print(f"Test Accuracy for {dataset.split('.')[0]}: {test_accuracy * 100:.2f}%")

    # Confusion Matrix
    y_pred = model.predict(X_test)
    y_pred_classes = np.argmax(y_pred, axis=1)
    cm = confusion_matrix(y_test, y_pred_classes)

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=encoder.categories_[0], yticklabels=encoder.categories_[0])
    plt.title(f'Confusion Matrix - {dataset.split(".")[0]}')
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.savefig(os.path.join(output_dir, f'{dataset.split(".")[0]}_confusion_matrix.png'))
    plt.close()

    # ROC Curve
    fpr, tpr, thresholds = roc_curve(y_test_encoded.ravel(), y_pred.ravel())
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, label=f'AUC = {roc_auc:.2f}')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.title(f'ROC Curve - {dataset.split(".")[0]}')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend()
    plt.savefig(os.path.join(output_dir, f'{dataset.split(".")[0]}_roc_curve.png'))
    plt.close()

    # Precision-Recall Curve
    precision, recall, _ = precision_recall_curve(y_test_encoded.ravel(), y_pred.ravel())

    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision)
    plt.title(f'Precision-Recall Curve - {dataset.split(".")[0]}')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.savefig(os.path.join(output_dir, f'{dataset.split(".")[0]}_precision_recall_curve.png'))
    plt.close()

    # Feature Distribution Visualization
    plt.figure(figsize=(12, 4))
    sns.histplot(features.flatten(), bins=30)
    plt.title(f'Feature Distribution - {dataset.split(".")[0]}')
    plt.savefig(os.path.join(output_dir, f'{dataset.split(".")[0]}_feature_distribution.png'))
    plt.close()

    # Reshape the features for boxplot
    features_avg = features.mean(axis=-1)  # shape will be (288, 22)

    # Convert to DataFrame for plotting
    features_df = pd.DataFrame(features_avg)

    # Boxplots of Features
    plt.figure(figsize=(12, 4))
    sns.boxplot(data=features_df)
    plt.title(f'Boxplot of Features - {dataset.split(".")[0]}')
    plt.savefig(os.path.join(output_dir, f'{dataset.split(".")[0]}_boxplot.png'))
    plt.show()

    # Average across the last dimension (e.g., time dimension)
    features_avg = features.mean(axis=-1)  # Now shape will be (288, 22)

    # Correlation Heatmap
    plt.figure(figsize=(12, 8))
    # Calculate correlation matrix for the averaged features
    correlation_matrix = np.corrcoef(features_avg, rowvar=False)
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', square=True)
    plt.title(f'Correlation Heatmap - {dataset.split(".")[0]}')
    plt.savefig(os.path.join(output_dir, f'{dataset.split(".")[0]}_correlation_heatmap.png'))
    plt.show()

    # Dimensionality Reduction Visualization with PCA
    pca = PCA(n_components=2)
    pca_result = pca.fit_transform(features.reshape(features.shape[0], -1))

    plt.figure(figsize=(12, 8))
    plt.scatter(pca_result[:, 0], pca_result[:, 1], c=labels, cmap='viridis')
    plt.title(f'PCA - {dataset.split(".")[0]}')
    plt.xlabel('PCA Component 1')
    plt.ylabel('PCA Component 2')
    plt.savefig(os.path.join(output_dir, f'{dataset.split(".")[0]}_pca.png'))
    plt.close()

    # Dimensionality Reduction Visualization with t-SNE
    tsne = TSNE(n_components=2, perplexity=30)
    tsne_result = tsne.fit_transform(features.reshape(features.shape[0], -1))

    plt.figure(figsize=(12, 8))
    plt.scatter(tsne_result[:, 0], tsne_result[:, 1], c=labels, cmap='viridis')
    plt.title(f't-SNE - {dataset.split(".")[0]}')
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    plt.savefig(os.path.join(output_dir, f'{dataset.split(".")[0]}_tsne.png'))
    plt.close()

    # Training Progress Plots
    plt.figure(figsize=(12, 4))
    plt.plot(history.history['accuracy'], label='Training Accuracy')
    plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
    plt.title(f'Accuracy vs. Epochs - {dataset.split(".")[0]}')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.savefig(os.path.join(output_dir, f'{dataset.split(".")[0]}_training_accuracy.png'))
    plt.close()

    plt.figure(figsize=(12, 4))
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title(f'Loss vs. Epochs - {dataset.split(".")[0]}')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(os.path.join(output_dir, f'{dataset.split(".")[0]}_training_loss.png'))
    plt.close()

# End of processing for each dataset

Extracting EDF parameters from /content/A01E.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 686999  =      0.000 ...  2747.996 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']
Not setting metadata
288 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 288 events and 176 original time points ...
0 bad epochs dropped
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']


  epoch = mne.Epochs(raw, events, event_id=[7, 8, 9, 10], on_missing='warn')
  epoch = mne.Epochs(raw, events, event_id=[7, 8, 9, 10], on_missing='warn')
  epoch = mne.Epochs(raw, events, event_id=[7, 8, 9, 10], on_missing='warn')
  super().__init__(**kwargs)


Epoch 1/100


  return self.fn(y_true, y_pred, **self._fn_kwargs)


[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 268ms/step - accuracy: 1.0000 - loss: 0.0000e+00 - val_accuracy: 1.0000 - val_loss: 0.0000e+00
Epoch 2/100
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 73ms/step - accuracy: 1.0000 - loss: 0.0000e+00 - val_accuracy: 1.0000 - val_loss: 0.0000e+00
Epoch 3/100
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 73ms/step - accuracy: 1.0000 - loss: 0.0000e+00 - val_accuracy: 1.0000 - val_loss: 0.0000e+00
Epoch 4/100
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 75ms/step - accuracy: 1.0000 - loss: 0.0000e+00 - val_accuracy: 1.0000 - val_loss: 0.0000e+00
Epoch 5/100
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 71ms/step - accuracy: 1.0000 - loss: 0.0000e+00 - val_accuracy: 1.0000 - val_loss: 0.0000e+00
Epoch 6/100
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 80ms/step - accuracy: 1.0000 - loss: 0.0000e+00 - val_accuracy: 1.0000 - val_loss: 0.0000e+00
Epo

  raw.plot_psd(fmin=1, fmax=40, show=False)


ValueError: Info (22) and data (288) must have same number of channels.

<Figure size 640x480 with 0 Axes>

In [None]:
!zip -r plots.zip plots/
!ls
from google.colab import files
files.download('plots.zip')

updating: plots/ (stored 0%)
updating: plots/A04E_training_loss.png (deflated 23%)
updating: plots/A07E_accuracy.png (deflated 22%)
updating: plots/A03E_correlation_heatmap.png (deflated 3%)
updating: plots/A02T_correlation_heatmap.png (deflated 3%)
updating: plots/A04E_correlation_heatmap.png (deflated 3%)
updating: plots/A08E_confusion_matrix.png (deflated 21%)
updating: plots/A07T_accuracy.png (deflated 11%)
updating: plots/A01T_boxplot.png (deflated 14%)
updating: plots/A09E_boxplot.png (deflated 13%)
updating: plots/A08E_roc_curve.png (deflated 16%)
updating: plots/A04T_roc_curve.png (deflated 16%)
updating: plots/A01E_loss.png (deflated 24%)
updating: plots/A09T_confusion_matrix.png (deflated 20%)
updating: plots/A07E_pca.png (deflated 9%)
updating: plots/A08T_loss.png (deflated 11%)
updating: plots/A03E_roc_curve.png (deflated 17%)
updating: plots/A05T_pca.png (deflated 12%)
updating: plots/A04E_pca.png (deflated 7%)
updating: plots/A04E_tsne.png (deflated 6%)
updating: plots/A0

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>