<a href="https://colab.research.google.com/github/payaldas30/Impulse-NITK/blob/main/BestModel_RF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive. The dataset
# paths configured in this notebook all live under that mount point, so this
# cell has to run before any data is read.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Root folder of the EEG dataset on the mounted Google Drive. Both split
# directories are derived from it, so pointing the notebook at another copy of
# the data only requires changing this one value.
EEG_DATA_ROOT = '/content/drive/MyDrive/Impulse/EEG_Data'

# Each split directory is expected to hold one sub-folder per key of
# class_folders (that is how extract_features walks it).
train_data_path = f'{EEG_DATA_ROOT}/train_data'
validation_data_path = f'{EEG_DATA_ROOT}/validation_data'

# Class sub-folder name -> integer label used as the classification target.
class_folders = {
    "Complex_Partial_Seizures": 0,
    "Electrographic_Seizures": 1,
    "Video_detected_Seizures_with_no_visual_change_over_EEG": 2,
    "Normal": 3
}

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
from scipy.fft import fft, fftfreq
from scipy.signal import spectrogram
!pip install PyWavelets
import pywt
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.metrics import classification_report, roc_auc_score, balanced_accuracy_score, roc_curve, auc



In [5]:
def compute_fft_features(signal, sampling_rate):
    """Return the peak frequency and mean spectral magnitude of a 1-D signal.

    Only the first ``len(signal) // 2`` FFT bins (the non-negative
    frequencies) are inspected. The 0 Hz (DC) bin takes part in the peak
    search, so a signal dominated by a constant offset reports a peak of 0.0.

    Args:
        signal: One-dimensional array-like of samples; at least two samples
            are needed for the half spectrum to contain a bin.
        sampling_rate: Number of samples per unit of time; the returned peak
            frequency is expressed in the matching unit (Hz when the rate is
            in samples per second).

    Returns:
        Tuple ``(peak_freq, mean_amp)``: the frequency of the bin with the
        largest magnitude, and the mean magnitude over the inspected bins.

    Raises:
        ValueError: If ``signal`` is not one-dimensional or has fewer than
            two samples.
    """
    samples = np.asarray(signal)
    if samples.ndim != 1:
        # len() would measure the first axis while the FFT runs over the last
        # one, so anything but a flat signal would silently mix axes.
        raise ValueError(
            f"signal must be one-dimensional, got shape {samples.shape}"
        )
    n_samples = samples.shape[0]
    if n_samples < 2:
        # With fewer than two samples the half spectrum is empty and argmax
        # would fail with a message unrelated to the real cause.
        raise ValueError(
            f"signal must contain at least 2 samples, got {n_samples}"
        )

    half = n_samples // 2
    magnitudes = np.abs(np.fft.fft(samples)[:half])
    frequencies = np.fft.fftfreq(n_samples, d=1 / sampling_rate)[:half]

    peak_freq = frequencies[np.argmax(magnitudes)]
    mean_amp = np.mean(magnitudes)

    return peak_freq, mean_amp

def compute_zcr(signal):
    """Return the zero-crossing rate of a signal.

    A crossing is counted for every pair of consecutive samples whose signs
    are strictly opposite; a pair that contains an exact zero is not counted.
    The count is divided by the total number of samples (not by the number
    of pairs).

    Args:
        signal: Array-like of real-valued samples.

    Returns:
        The number of strict sign changes divided by ``len(signal)``, as a
        float. An empty signal yields 0.0.
    """
    samples = np.asarray(signal)
    if len(samples) == 0:
        # Avoid a 0 / 0 division; an empty signal has no crossings.
        return 0.0

    previous, following = samples[:-1], samples[1:]
    # Compare the signs directly instead of testing `previous * following < 0`:
    # the product wraps around for narrow integer dtypes and underflows to zero
    # for very small floats, and either effect hides real crossings.
    sign_flips = ((previous < 0) & (following > 0)) | ((previous > 0) & (following < 0))
    return np.count_nonzero(sign_flips) / len(samples)
def extract_features(data_path, class_folders, sampling_rate=256):
    """Build a feature matrix and label vector from per-class folders of .npy files.

    For every ``.npy`` file found in ``data_path/<class_name>``, the first row
    of the stored array is taken as the signal and reduced to three features:
    FFT peak frequency, mean FFT magnitude and zero-crossing rate.

    Args:
        data_path: Directory containing one sub-folder per key of
            ``class_folders``.
        class_folders: Mapping of sub-folder name to the label assigned to
            every file in that folder.
        sampling_rate: Sampling rate forwarded to ``compute_fft_features``
            (presumably in Hz; confirm against the dataset description).

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays. ``features`` has one
        ``[peak_freq, mean_amp, zcr]`` row per processed file and ``labels``
        holds the matching class label for each row.
    """
    features = []
    labels = []
    for class_name, class_label in class_folders.items():
        class_folder = os.path.join(data_path, class_name)
        # os.listdir returns entries in arbitrary order; sorting keeps the row
        # order (and therefore anything seeded downstream) identical across runs.
        for file_name in sorted(os.listdir(class_folder)):
            if not file_name.endswith('.npy'):
                continue
            data = np.load(os.path.join(class_folder, file_name))
            if data.shape[0] == 0:
                # No first row to read; skip the file.
                continue
            # Only the first row is used; assumes arrays are laid out as
            # (channels, samples), so this is channel 1.
            signal = data[0, :]
            peak_freq, mean_amp = compute_fft_features(signal, sampling_rate)
            zcr = compute_zcr(signal)
            features.append([peak_freq, mean_amp, zcr])
            labels.append(class_label)
    return np.array(features), np.array(labels)

# Extract train and validation features
X_train, y_train = extract_features(train_data_path, class_folders)
X_val, y_val = extract_features(validation_data_path, class_folders)

# Fail fast if a split came back empty (wrong path, Drive not mounted, no .npy
# files) instead of hitting an obscure error later during fitting or scoring.
assert len(X_train) == len(y_train) and len(X_train) > 0, (
    f"No training samples were extracted from {train_data_path}"
)
assert len(X_val) == len(y_val) and len(X_val) > 0, (
    f"No validation samples were extracted from {validation_data_path}"
)

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve, auc
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import numpy as np

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, roc_auc_score, balanced_accuracy_score, roc_curve, auc
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt

# Class labels in ascending order, taken from class_folders so the label list
# and the class count are not repeated as literals below.
# NOTE: label_binarize returns a single column for two-class problems, so the
# per-class indexing further down assumes three or more classes.
class_labels = sorted(class_folders.values())

# Fixed random_state keeps the forest reproducible between runs.
rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
rf_model.fit(X_train, y_train)

# predict_proba columns follow rf_model.classes_; make sure they line up with
# the binarized validation labels before indexing both by column.
assert list(rf_model.classes_) == class_labels, (
    "Training labels do not cover every class in class_folders"
)

y_pred = rf_model.predict(X_val)
y_prob = rf_model.predict_proba(X_val)
y_val_binarized = label_binarize(y_val, classes=class_labels)

class_report = classification_report(y_val, y_pred)
balanced_acc = balanced_accuracy_score(y_val, y_pred)
# y_val_binarized is a one-hot indicator matrix, so this is the unweighted mean
# of the per-class one-vs-rest AUCs.
roc_auc = roc_auc_score(y_val_binarized, y_prob, average='macro', multi_class='ovr')

print("Classification Report:\n", class_report)
print(f"Balanced Accuracy: {balanced_acc:.4f}")
print(f"ROC AUC Score: {roc_auc:.4f}")


# One one-vs-rest ROC curve per class, all on the same axes.
fig, ax = plt.subplots(figsize=(8, 6))
for column, label in enumerate(class_labels):
    fpr, tpr, _thresholds = roc_curve(y_val_binarized[:, column], y_prob[:, column])
    ax.plot(fpr, tpr, label=f"Class {label} (AUC = {auc(fpr, tpr):.2f})")

ax.set_title("ROC Curves for Random Forest")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend()
ax.grid(True)
plt.show()

Classification Report:
               precision    recall  f1-score   support

           0       0.73      0.76      0.74       549
           1       0.42      0.37      0.39       137
           2       0.81      0.62      0.70        21
           3       0.80      0.80      0.80       696

    accuracy                           0.74      1403
   macro avg       0.69      0.64      0.66      1403
weighted avg       0.73      0.74      0.74      1403

Balanced Accuracy: 0.6363
ROC AUC Score: 0.8808
