## Import all required libraries

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import mne
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pywt 
from PIL import Image
from datetime import datetime, timezone
from utility import *
from model import *

## Load the classification model

### Load the Diffusion Classifier (slower but more accurate)

In [2]:
# Build the Diffusion Classifier on top of the pretrained WaveGrad diffusion
# model and restore its best checkpoint. Run EITHER this cell or the CNN cell
# below: both bind the global `model` used by the prediction cells.

# b_ and t_ are forwarded to DiffusionClassifier and also select the checkpoint
# directory ("DC_b{b_}_t{t_}").
# NOTE(review): their exact meaning is defined by DiffusionClassifier in model.py — confirm there.
b_ = 5
t_ = 20
MODEL_FILE_DIRC_WaveGrad = MODEL_FILE_DIRC + "/WaveGrad"
MODEL_FILE_DIRC_DC       = MODEL_FILE_DIRC + "/DC"
MODEL_FILE_DIRC_DC_bt = MODEL_FILE_DIRC_DC + f"_b{b_}_t{t_}"

# Load the pretrained diffusion model
diffusion_checkpoint_path = f"{MODEL_FILE_DIRC_WaveGrad}/Advanced_Diffusion_best.pt"
if not os.path.exists(diffusion_checkpoint_path):
    # Raise a real exception: `raise("...")` with a plain string is itself a
    # TypeError in Python 3 and would hide the intended message.
    raise FileNotFoundError(f"No pretrained diffusion model exists at {diffusion_checkpoint_path}")

diffusion_model = WaveGradNN(config).to(device)
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
state_dict_loaded = torch.load(diffusion_checkpoint_path, map_location=device)
diffusion_model.load_state_dict(state_dict_loaded["model"])


# Create Diffusion Classifier, turn to evaluation mode, and Load the best model
model = DiffusionClassifier(config, diffusion_model, device, t_, b_).to(device)
model.eval()
state_dict_loaded = torch.load(f"{MODEL_FILE_DIRC_DC_bt}/DC_best.pt", map_location=device)
# The checkpoint was previously read but never applied, leaving the classifier
# with untrained weights.
# NOTE(review): assumes the weights live under the "model" key, like the
# WaveGrad and CNN checkpoints in this notebook — confirm against the training script.
model.load_state_dict(state_dict_loaded["model"])

### Load the CNN (faster but less accurate)

In [2]:
# Build the CNN classifier and restore its best checkpoint. Run EITHER this
# cell or the Diffusion Classifier cell above: both bind the global `model`
# used by the prediction cells.
MODEL_FILE_DIRC_CNN       = MODEL_FILE_DIRC + "/CNN_synthesized"

model      = CNN().to(device)
model.eval()
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
state_dict_loaded = torch.load(f"{MODEL_FILE_DIRC_CNN}/CNN_best.pt", map_location=device)
# Kept as the last expression so the notebook displays the key-matching report.
model.load_state_dict(state_dict_loaded["model"])

<All keys matched successfully>

## Predict the result

In [3]:
# Recording to inspect (path is relative to the notebook's working directory).
filename = "Sample_Data/Export-#1495_2664_2022-08-15_21-01-13.eeg"

# Convert the Compumedics export to an MNE Raw object and pull its
# annotations into a DataFrame.
raw = Compumedics(filename).export_to_mne_raw()
annotation_df = raw.annotations.to_data_frame()

# Display only the annotations whose description mentions "spike"
# (case-insensitive substring match).
mentions_spike = [("spike" in text.lower()) for text in annotation_df["description"]]
annotation_df.loc[mentions_spike]

Start to import c:\Users\xinju\Desktop\Python_Jupyter\Y3_Sem2_Thesis_XAI\Sample_Data\Export-#1495_2664_2022-08-15_21-01-13.eeg
Checking necessary files...
Found .sdy at c:\Users\xinju\Desktop\Python_Jupyter\Y3_Sem2_Thesis_XAI\Sample_Data\Export-#1495_2664_2022-08-15_21-01-13.eeg\export-#1495_2664_2022-08-15_21-01-13.sdy
Checking the content of c:\Users\xinju\Desktop\Python_Jupyter\Y3_Sem2_Thesis_XAI\Sample_Data\Export-#1495_2664_2022-08-15_21-01-13.eeg\eegdata
Found EEG header at c:\Users\xinju\Desktop\Python_Jupyter\Y3_Sem2_Thesis_XAI\Sample_Data\Export-#1495_2664_2022-08-15_21-01-13.eeg\eegdata\eegdata.ini
Opening all (7) .rda file(s)
Checking optional files...
Checking electrode placement file...
Checking event database...
Event database ok
Reading Compumedics header (.sdy) file
Compumedics header loaded
33 channel(s): ['EMG2', 'ECGL', 'ECGR', 'HV', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'Fz', 'Cz', 'Pz', 'T1', 'T2', 'EOGL', 'EOGR', 'EMG1', 'Fp1', 'F

  raw.set_annotations(ann)


Raw: <RawArray | 33 x 7227504 (28232.4 s), ~1.78 GB, data loaded>
Conversion complete


Unnamed: 0,onset,duration,description


In [4]:
# Preprocess each recording, cut it into windows and classify every window as
# spike / no spike with whichever `model` was loaded above.
# `outputs`, `num_spike` and `eeg_filename` are overwritten on every iteration,
# so the cells below only see the results of the LAST file in EEG_file_list.
SECONDS_TO_TRASH = 10  # seconds passed to process_edf as SKIP_FIRST / SKIP_LAST
EEG_file_list = ["Sample_Data/Export-#1495_2664_2022-08-15_21-01-13.eeg"]
for eeg_filename in EEG_file_list:
    num = 1000  # identifier passed to process_edf / get_dataloader (the log shows it names EEG_csv/eeg1000.csv)

    ## Process the edf file
    process_edf(eeg_filename, num,
                SKIP_FIRST = SECONDS_TO_TRASH,
                SKIP_LAST  = SECONDS_TO_TRASH)

    ## Get the data
    datasets, _, datasets_DWT = get_dataloader([num], get_dataloader=False, shuffle=False)
    datasets = torch.cat(datasets, dim=0).to(torch.float32)
    datasets = datasets.to(device)
    datasets_DWT = torch.cat(datasets_DWT, dim=0).to(torch.float32)
    datasets_DWT = datasets_DWT.to(device)
    # NOTE(review): the recorded output reports 1411 windows but a tensor of 987
    # samples — confirm get_dataloader is not dropping windows (e.g. a train
    # split), otherwise window index i no longer maps to time i*DURATION.

    ## Classify the existance of spike
    # Fail early with a clear message instead of a NameError on `outputs` below.
    if not isinstance(model, (DiffusionClassifier, CNN)):
        raise TypeError(f"Unsupported model type: {type(model).__name__}")

    outputs = []
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        if isinstance(model, DiffusionClassifier):
            print("Using Diffusion Classifier...")
            # Labels are unknown at prediction time; zeros are placeholders.
            # num_workers=0: the tensors already sit on `device`, and sharing
            # CUDA tensors with worker processes is unsupported on Windows.
            dataloader = DataLoader(dataset = Dataset_Class(datasets, datasets_DWT, torch.zeros(len(datasets_DWT), dtype=torch.int8)),
                                    batch_size = BATCH_SIZE, shuffle = False, num_workers=0)
            for epoch_data, epoch_data_DWT, label in dataloader:
                epoch_data_DWT = epoch_data_DWT.to(device)
                epoch_data     = epoch_data.to(device)

                output    = model(epoch_data_DWT, epoch_data)
                # sigmoid + round turns each logit into a hard 0/1 decision
                outputs.append(torch.round(torch.sigmoid(output)).detach().flatten().cpu().numpy())
        else:
            print("Using CNN...")
            dataloader = DataLoader(dataset = Dataset_Class1(datasets, torch.zeros(len(datasets), dtype=torch.int8)),
                                    batch_size = BATCH_SIZE, shuffle = False, num_workers=0)
            for epoch_data, _ in dataloader:
                epoch_data = epoch_data.to(device)
                output    = model(epoch_data)

                outputs.append(torch.round(torch.sigmoid(output)).detach().flatten().cpu().numpy())

    # np.concatenate raises on an empty list, so a file with no windows yields
    # an empty prediction array instead of crashing.
    outputs = np.concatenate(outputs, axis=0) if outputs else np.zeros(0, dtype=np.float32)

    num_spike = int(outputs.sum())
    print(f"There is total {num_spike} spike detected in {eeg_filename}")
    print(outputs, "\n")

For the file:  Sample_Data/Export-#1495_2664_2022-08-15_21-01-13.eeg
Start to import c:\Users\xinju\Desktop\Python_Jupyter\Y3_Sem2_Thesis_XAI\Sample_Data\Export-#1495_2664_2022-08-15_21-01-13.eeg
Checking necessary files...
Found .sdy at c:\Users\xinju\Desktop\Python_Jupyter\Y3_Sem2_Thesis_XAI\Sample_Data\Export-#1495_2664_2022-08-15_21-01-13.eeg\export-#1495_2664_2022-08-15_21-01-13.sdy
Checking the content of c:\Users\xinju\Desktop\Python_Jupyter\Y3_Sem2_Thesis_XAI\Sample_Data\Export-#1495_2664_2022-08-15_21-01-13.eeg\eegdata
Found EEG header at c:\Users\xinju\Desktop\Python_Jupyter\Y3_Sem2_Thesis_XAI\Sample_Data\Export-#1495_2664_2022-08-15_21-01-13.eeg\eegdata\eegdata.ini
Opening all (7) .rda file(s)
Checking optional files...
Checking electrode placement file...
Checking event database...
Event database ok
Reading Compumedics header (.sdy) file
Compumedics header loaded
33 channel(s): ['EMG2', 'ECGL', 'ECGR', 'HV', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 'T3', 'T4', 'T5', 

  raw.set_annotations(ann)


Raw: <RawArray | 33 x 7227504 (28232.4 s), ~1.78 GB, data loaded>
Conversion complete
Before drop some row, shape = (7227504, 34)
After  drop some row, shape = (7222384, 34)
Setting up band-pass filter from 0.5 - 70 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 70.00 Hz
- Upper transition bandwidth: 17.50 Hz (-6 dB cutoff frequency: 78.75 Hz)
- Filter length: 1691 samples (6.605 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    3.4s


Creating RawArray with float64 data, n_channels=19, n_times=7222384
    Range : 0 ... 7222383 =      0.000 ... 28212.434 secs
Ready.
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Sample_Data/Export-#1495_2664_2022-08-15_21-01-13.eeg does not contain NULL value, the process will continue
The dataframe of Sample_Data/Export-#1495_2664_2022-08-15_21-01-13.eeg have been saved to EEG_csv/eeg1000.csv

The data from EEG_csv/eeg1000.csv is loaded 
There is no spike in this eeg file
(1411, 1280, 19)
EEG1000 has 1411 windows of data 


> > > Train    data  has shape: torch.Size([987, 19, 1280]) when duration = 10 seconds
> > > Data after DWT has shape: torch.Size([987, 19, 1282])
> > > Label              shape: torch.Size([987])
There is total 0 spike detected in Sample_Data/Export-#1495_2664_2022-08-15_21-01-13.eeg
[0. 0. 0. ... 0. 0. 0.] 



In [5]:
# Write the predicted spikes back as annotations on the recording and export it.
# Uses `eeg_filename`, `outputs` and `num_spike` from the prediction cell, i.e.
# the results of the last file processed there.
if num_spike >= 1:
    raw = Compumedics(eeg_filename).export_to_mne_raw() # Return a raw object

    # One onset per window predicted as spike. This must iterate `outputs`
    # (0/1 predictions for ALL windows), not `output` (raw logits of the last
    # batch only), otherwise the onsets are wrong and their count disagrees
    # with the duration/description lists.
    # NOTE(review): the +5 presumably places the marker at the window midpoint
    # (DURATION / 2 for 10 s windows) — confirm.
    spike_onsets = [SECONDS_TO_TRASH+5+i*DURATION for i, out in enumerate(outputs) if out==1]

    # Get the mne Annotations object
    annotations = mne.Annotations(onset=spike_onsets,
                                    duration=[0.0] * len(spike_onsets),
                                    description=["Spike"] * len(spike_onsets),
                                    orig_time=raw.info["meas_date"])

    # Print the annotations to verify
    print(annotations.to_data_frame())

    # Failed to save again to edf, have problem on it 
    # Set the annotations
    raw.set_annotations(annotations)
    # NOTE(review): the file is written in EEGLAB format but keeps an ".eeg"
    # extension — confirm downstream tools accept that, or switch to ".set".
    new_edf_filename = eeg_filename[:-4] + "_Processed.eeg"

    mne.export.export_raw(new_edf_filename, raw, fmt="eeglab")
    print(f"The label file have been saved from {eeg_filename} to {new_edf_filename}")

In [7]:
# Print one line per window that was classified as containing a spike.
# The printed time is window_index * DURATION, i.e. it does not include the
# SECONDS_TO_TRASH offset that the annotation cell adds.
for window_idx in range(len(outputs)):
    has_spike = outputs[window_idx]
    if not has_spike:
        continue
    print(f"At {window_idx*DURATION:<4}s, has spike? {has_spike}")