In [1]:
import matplotlib
import matplotlib.pyplot as plt
import mne
import numpy as np
from functions import *

mne.set_log_level('error')
matplotlib.use('TkAgg')

### Loading the data

In [3]:
# Folder holding the per-subject recordings, relative to this notebook.
data_path = "../Data Set Alcohol Detector/Subjects"

# Recording to analyse. The file name repeats the subject and session IDs,
# so they are defined once and reused to keep folder and file name in sync.
subject_id = "sub-P019"
session_id = "ses-S002"
run_id = "run-005"
p_id = f"/{subject_id}/{session_id}/{subject_id}_{session_id}_task-Default_{run_id}"

# p_id already starts with "/", so it is appended directly; adding another
# separator produced "Subjects//sub-P019/...".
raw = mne.io.read_raw_fif(data_path + p_id + "_eeg.fif", preload=True)

Plotting the electrode placement on the scalp

In [6]:
# Draw the sensor layout with channel names; the trailing ";" hides the figure repr.
raw.plot_sensors(
    show_names=True,
);

### Filtering the data
High-pass filtering is included to remove slow drift of the signal, while low-pass filtering removes power-line noise.


In [4]:
# Band-pass cut-off frequencies passed to Raw.filter (l_freq, h_freq).
low_cut = 0.1
hi_cut = 30

# Filter a copy so that `raw` keeps the unfiltered recording.
raw_filt = raw.copy().filter(low_cut, hi_cut)

# Power spectral density of the filtered data, as a check of the pass band.
raw_filt.compute_psd(fmax=125).plot(picks="data", exclude="bads");

# Browse the filtered traces. The display filter reuses the cut-offs above
# instead of repeating the literals, so the two can never drift apart.
# The trailing ";" keeps the figure repr out of the cell output.
raw_filt.plot(highpass=low_cut, lowpass=hi_cut, duration=200);

<MNEBrowseFigure size 800x800 with 4 Axes>

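Choose the channels to use for this experiment. Fp1 is one of the frontal electrodes and PO8 is a parieto-occipital electrode. The data and time are extracted from the filtered raw data as NumPy arrays, and `channel_nr` selects which of the chosen channels is analysed in the rest of the notebook.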

In [22]:
# Channels to extract; the order here defines the row order of `data`.
channels = ['PO8', 'Fp1']

# ordered=True makes the returned indices follow the order of `channels`.
# Without it, MNE versions whose default is unordered return the indices in
# file order, so data[idx] would not necessarily belong to channels[idx] and
# the plot labels below could be swapped.
picks = mne.pick_channels(
    ch_names=raw_filt.info['ch_names'],
    include=channels,
    ordered=True,
)
data, time = raw_filt[picks, :]

# Index into `channels` / rows of `data` of the channel analysed below.
channel_nr = 1

[[ 2.08166817e-17 -1.15323783e-05 -2.07836756e-05 ... -2.90512535e-05
  -1.65760620e-05  1.73472348e-17]
 [ 6.93889390e-18 -7.09227182e-06 -1.26138790e-05 ...  2.15769186e-05
   1.35345392e-05  1.17093835e-17]]


The signal is visualized using matplotlib.

In [33]:
# One subplot row per selected channel. The grid size is derived from
# `channels` so that it always agrees with the figure height; a fixed
# 2-row grid fails as soon as more than two channels are picked.
n_channels = len(channels)
plt.figure(figsize=(10, 4*n_channels))
for idx, channel in enumerate(channels):
    ax = plt.subplot(n_channels, 1, idx+1)
    plt.plot(time, data[idx], label=channel)
    plt.ylabel('EEG Amplitude')
    plt.grid(True)
    # Only the bottom subplot keeps its x tick labels and axis label.
    if idx < n_channels - 1:
        ax.set_xticklabels([])
    else:
        plt.xlabel('Time [s]')
    plt.legend()

plt.tight_layout()
plt.show()

In [40]:
# Sampling frequency stored in the recording's measurement info.
sample_rate = raw.info['sfreq']

# Analysis window: starts at s1 seconds and lasts window_size seconds.
window_size = 8
s1 = 100
s2 = s1 + window_size

# Window edges converted from seconds to sample indices.
delta1, delta2 = (int(sample_rate * edge) for edge in (s1, s2))

# Samples of the selected channel and the matching time stamps.
d = data[channel_nr, delta1:delta2]
t = time[delta1:delta2]
num_samples = d.shape[0]

In [41]:
# Plot the extracted window of the selected channel.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(t, d)
ax.set(
    xlabel='Time (s)',
    ylabel='EEG Amplitude',
    title=f'EEG Data for {channels[channel_nr]} Channel',
)
ax.grid(True)
fig.tight_layout()
plt.show()

In [10]:
# Fourier transform of the windowed signal; fft() is not defined in this
# notebook (presumably provided by the `functions` star import).
freq, fft_freq = fft(d, num_samples, sample_rate)

In [11]:
# Short-time Fourier transform of the windowed signal. plot_stft() is not
# defined in this notebook (presumably provided by the `functions` star import).
# NOTE(review): 50 and 5 look like nperseg and noverlap, matching the keyword
# names used in the per-IMF STFT cell - confirm the positional order.
f, dt, ssx = plot_stft(d, sample_rate, 50, 5)

In [12]:
# Decompose the windowed signal; later cells iterate over `imfs` as the
# intrinsic mode functions and drop the last entry.
# NOTE(review): hht() is not defined here - presumably a Hilbert-Huang / EMD
# helper from the `functions` star import; confirm what the last entry holds.
imfs = hht(t, d, plot=False)

In [13]:
#cwtm = wt(t, d, sample_rate, w=3.0)

### Plot raw IMFs

In [28]:
# Every component except the last one returned by hht().
filt_imfs = imfs[:-1]

# Figure height grows with the number of IMFs (2 inches per row).
plt.figure(figsize=(10, 2*len(filt_imfs)))

for idx, imf in enumerate(filt_imfs):
    ax = plt.subplot(len(filt_imfs), 1, idx+1)
    # Label each row with the same 1-based numbering as the FFT-per-IMF plot.
    plt.plot(t, imf, label=f"IMF {idx+1}")
    plt.legend()

    # Only the bottom subplot keeps its x tick labels and axis label.
    if idx < len(filt_imfs) - 1:
        ax.set_xticklabels([])
    else:
        plt.xlabel('Time [s]')

    # MNE stores EEG samples in volts, so the axis is labelled in V, not mV.
    # NOTE(review): assumes hht() does not rescale the signal - confirm.
    plt.ylabel('V')

# Tighten the layout, then leave room at the bottom for the x label.
plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)
plt.subplots_adjust(bottom=0.06)
plt.show()

### Plot Fourier Transform of each IMF

In [27]:
# Every component except the last one returned by hht().
filt_imfs = imfs[:-1]
n_imfs = len(filt_imfs)

# One spectrum per row; figure height grows with the number of IMFs.
fig = plt.figure(figsize=(10, 2*n_imfs))

for idx, imf in enumerate(filt_imfs):
    freq, fft_res = fft(imf, num_samples, sample_rate, plot=False)

    ax = fig.add_subplot(n_imfs, 1, idx+1)
    ax.plot(freq, fft_res, label=f"IMF {idx+1}")
    ax.set_xlim(0, 45)
    ax.legend()
    ax.set_ylabel('Amplitude')

    # Only the bottom row keeps its x tick labels and axis label.
    if idx == n_imfs - 1:
        ax.set_xlabel('Frequency [Hz]')
    else:
        ax.set_xticklabels([])

# Tighten the layout, then leave room at the bottom for the x label.
fig.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)
fig.subplots_adjust(bottom=0.06)
plt.show()

### Plot Short-Time Fourier Transform of each IMF

In [16]:
# Every component except the last one returned by hht().
filt_imfs = imfs[:-1]

# Compute each STFT once and reuse it for both the colour-scale limits and
# the plots, instead of running plot_stft() twice per IMF.
stft_results = [
    plot_stft(imf, sample_rate=sample_rate, nperseg=100, noverlap=10, plot=False)
    for imf in filt_imfs
]
magnitudes = [np.abs(spectrum) for _, _, spectrum in stft_results]

# Global limits so that every subplot shares one colour scale.
vmin = min((magnitude.min() for magnitude in magnitudes), default=np.inf)
vmax = max((magnitude.max() for magnitude in magnitudes), default=-np.inf)

plt.figure(figsize=(16, 4*len(filt_imfs)))

mesh = None
for idx, ((f, dt, Sxx), magnitude) in enumerate(zip(stft_results, magnitudes)):
    plt.subplot(len(filt_imfs), 1, idx+1)
    mesh = plt.pcolormesh(dt, f, magnitude, shading='gouraud', vmin=vmin, vmax=vmax)
    plt.ylabel('f [Hz]')
    plt.ylim([0, 50])

# The x label goes on the last (bottom) subplot only.
plt.xlabel('Time [sec]')

# Reserve space on the right and draw one colorbar for the whole figure.
# The mappable is passed explicitly rather than relying on matplotlib's
# implicit "current image" lookup; all meshes share vmin/vmax, so any works.
plt.subplots_adjust(right=0.85)
cbar_ax = plt.gcf().add_axes([0.88, 0.15, 0.025, 0.7])
# NOTE(review): the plotted values are np.abs(Sxx); confirm that plot_stft()
# really returns dB before trusting the "[dB]" unit in this label.
plt.colorbar(mesh, aspect=10, cax=cbar_ax, label='Magnitude [dB]')

plt.show()

### Plot Wavelet transform

In [18]:
# Every component except the last one returned by hht().
filt_imfs = imfs[:-1]

# Frequency coordinates for the rows of each scalogram; identical for every
# IMF, so computed once outside the loop.
# NOTE(review): assumes wt() evaluates 100 frequencies spaced linearly from
# 1 Hz to sample_rate/2 - confirm against wt().
freq = np.linspace(1, sample_rate/2, 100)

plt.figure(figsize=(16, 4*len(filt_imfs)))
for idx, imf in enumerate(filt_imfs):
    cwtm = wt(t, imf, sample_rate, w=6.0)
    ax = plt.subplot(len(filt_imfs), 1, idx+1)
    plt.pcolormesh(t, freq, np.abs(cwtm), cmap='viridis', shading='gouraud')
    # Only the bottom subplot keeps its x tick labels and axis label.
    if idx < len(filt_imfs) - 1:
        ax.set_xticklabels([])
    else:
        plt.xlabel('Time [s]')
    # The vertical axis of the mesh is `freq`, so it is a frequency axis.
    plt.ylabel('Frequency [Hz]')
    plt.ylim([0, 50])

plt.show()

### Plot Hilbert Transform

In [31]:
# Every component except the last one returned by hht().
filt_imfs = imfs[:-1]

# Hilbert transform of each IMF: keep the amplitude envelope and the
# instantaneous frequency returned by ht().
envelopes = []
frequencies = []
for idx, imf in enumerate(filt_imfs):
    amplitude_envelope, instantaneous_frequency = ht(t, imf, sample_rate, plot=False)
    envelopes.append(amplitude_envelope)
    frequencies.append(instantaneous_frequency)

# Figure 1: each IMF with its amplitude envelope on top.
plt.figure(figsize=(16, 4*len(filt_imfs)))
for idx, envelope in enumerate(envelopes):
    ax = plt.subplot(len(filt_imfs), 1, idx+1)
    # 1-based numbering, matching the FFT-per-IMF plot.
    plt.plot(t, filt_imfs[idx], label=f'IMF {idx+1}')
    plt.plot(t, envelope, label='Envelope')
    # Without legend() the labels above are never displayed.
    plt.legend()
    if idx < len(filt_imfs) - 1:
        ax.set_xticklabels([])
    else:
        plt.xlabel('Time [s]')
plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)
plt.subplots_adjust(bottom=0.06)
plt.show()

# Figure 2: instantaneous frequency of each IMF.
plt.figure(figsize=(16, 4*len(filt_imfs)))
for idx, frequency in enumerate(frequencies):
    ax = plt.subplot(len(filt_imfs), 1, idx+1)
    # The frequency trace is plotted against t[1:], i.e. it is expected to be
    # one sample shorter than the IMF.
    plt.plot(t[1:], frequency, label=f'IMF {idx+1}')
    plt.legend()
    if idx < len(filt_imfs) - 1:
        ax.set_xticklabels([])
    else:
        plt.xlabel('Time [s]')
plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)
plt.subplots_adjust(bottom=0.06)
plt.show()