In [None]:
%pip install mne scikit-learn mne-icalabel torch

In [None]:
# Imports
import os
import pandas as pd
from ipywidgets import *
import numpy as np
import mne
from mne.preprocessing import ICA

# import torch
from mne_icalabel import label_components

import matplotlib.pyplot as plt
from mne.preprocessing import ICA
import scipy

import preprocessing
import utils

# Specify graph rendering method.
# TkAgg is a GUI backend: matplotlib figures open in separate Tk windows rather
# than inline in the notebook, so a display and a working Tk install are needed.
# The commented-out ipympl magic below is the inline, interactive alternative.
# %matplotlib widget
plt.switch_backend("TkAgg")

In [None]:
# Currently for 1 patient, will be generalized into a pipeline for all patients.
# Change SUBJECT_ID / SESSION_ID to load another participant or session.

DATASET_PATH = "./dataset"
SUBJECT_ID = "H_02"
SESSION_ID = "S1b"
# Evaluates to "TMS-EEG-H_02_S1b_{}_{}.vhdr"; the two remaining positional
# slots are the recording type (e.g. "rsEEG") and the condition ("pre"/"post").
FILENAME_TEMPLATE = f"TMS-EEG-{SUBJECT_ID}_{SESSION_ID}_{{}}_{{}}.vhdr"


def load_recording(recording="rsEEG", condition="pre"):
    """Read one BrainVision recording of the configured subject/session.

    Parameters
    ----------
    recording : str
        First filename slot, e.g. "rsEEG".
    condition : str
        Second filename slot, e.g. "pre" or "post".

    Returns
    -------
    The object returned by ``mne.io.read_raw_brainvision`` with data preloaded.
    """
    path = os.path.join(DATASET_PATH, FILENAME_TEMPLATE.format(recording, condition))
    return mne.io.read_raw_brainvision(path, preload=True)


rsEEG_pre_raw = load_recording("rsEEG", "pre")

# rsEEG_post_raw = load_recording("rsEEG", "post")

In [None]:
# Browse the unprocessed recording (functional form of rsEEG_pre_raw.plot()).
mne.viz.plot_raw(rsEEG_pre_raw)

In [None]:
# Work on a copy so that rsEEG_pre_raw keeps the unprocessed data for comparison.
clone = rsEEG_pre_raw.copy()

In [None]:
# Preprocess a fresh copy of the raw recording. Deriving `clone` from
# rsEEG_pre_raw (instead of feeding `clone` back into itself) makes the cell
# idempotent: re-running it no longer applies the preprocessing a second time.
clone = preprocessing.preprocess_rsEEG(rsEEG_pre_raw.copy())

In [None]:
# Browse the preprocessed copy to compare against the raw recording.
mne.viz.plot_raw(clone)

In [None]:
# Continuous wavelet transform (Ricker / "Mexican hat" wavelet) of every channel
# of the preprocessed recording.
#
# scipy.signal.cwt and scipy.signal.ricker were deprecated in SciPy 1.12 and
# removed in SciPy 1.15. The install cell pins no versions, so both are
# re-implemented here with the same definitions SciPy used.
# NOTE(review): the result holds n_channels arrays of len(widths) x n_samples
# float64 values; for long recordings that can be several GB. Consider cropping
# or decimating before this step.


def _ricker(points, a):
    """Return a Ricker wavelet with `points` samples and width parameter `a`.

    Same formula as the removed ``scipy.signal.ricker``.
    """
    amplitude = 2 / (np.sqrt(3 * a) * np.pi**0.25)
    offsets = np.arange(points) - (points - 1.0) / 2
    offsets_sq = offsets**2
    width_sq = a**2
    return amplitude * (1 - offsets_sq / width_sq) * np.exp(-offsets_sq / (2 * width_sq))


def _cwt_ricker(signal, widths):
    """Return the Ricker CWT of a 1-D `signal`, one row per entry of `widths`.

    Mirrors the removed ``scipy.signal.cwt(signal, scipy.signal.ricker, widths)``:
    each kernel is truncated to ``min(10 * width, len(signal))`` samples and
    convolved in "same" mode, so every row has ``len(signal)`` samples.
    """
    coeffs = np.empty((len(widths), len(signal)))
    for row, width in enumerate(widths):
        n_points = int(min(10 * width, len(signal)))
        # Reversed kernel, as in SciPy's cwt (a no-op for the symmetric Ricker).
        kernel = _ricker(n_points, width)[::-1]
        coeffs[row] = scipy.signal.convolve(signal, kernel, mode="same")
    return coeffs


data = clone.get_data()

widths = np.arange(1, 31)

cwtmatr = [_cwt_ricker(channel, widths) for channel in data]

In [None]:
# Summarise the CWT result instead of echoing every coefficient, which floods
# the notebook output: (number of channels, set of per-channel array shapes).
len(cwtmatr), {coeffs.shape for coeffs in cwtmatr}

In [None]:
# Visualise the Ricker CWT magnitude of one channel over the whole recording.
channel_number = 0

fig, ax = plt.subplots()
# extent is (left, right, bottom, top) in data units. x spans the recording in
# seconds; the old [-1, 1] was a placeholder unrelated to the data. y spans the
# wavelet widths, padded by half a step (assumes unit-spaced widths) so each row
# is centred on its width. origin="lower" keeps row 0 (the smallest width) at the
# bottom; with the default origin="upper" the scale axis was drawn upside down.
image = ax.imshow(
    np.abs(cwtmatr[channel_number]),
    aspect="auto",
    cmap="hot",
    origin="lower",
    extent=[clone.times[0], clone.times[-1], widths[0] - 0.5, widths[-1] + 0.5],
)
fig.colorbar(image, ax=ax, label="Magnitude")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Scale (wavelet width, samples)")
ax.set_title(
    f"Wavelet transform of channel {channel_number + 1} "
    f"({clone.ch_names[channel_number]})"
)
plt.show()

In [None]:
# Compute the wavelet transform of the data
epoch_duration = 5
overlap = 1
sfreq = clone.info["sfreq"]
duration_samples = int(epoch_duration * sfreq)
overlap_samples = int(overlap * sfreq)

onset = np.arange(
    0,
    clone.times[-1] * sfreq - duration_samples,
    duration_samples - overlap_samples,
)

events = np.vstack((onset, np.zeros_like(onset), np.ones_like(onset))).T.astype(int)
event_id = 1
epochs = mne.Epochs(
    clone,
    events,
    event_id=event_id,
    tmin=0,
    tmax=epoch_duration,
    baseline=None,
    preload=True,
)

frequencies = np.arange(1, 100)

wavelets = mne.time_frequency.tfr_morlet(
    epochs, freqs=frequencies, n_cycles=5, n_jobs=-1
)

In [None]:
# INDIVIDUAL PLOTS
bands = {
    "delta": (0.5, 4),
    "theta": (4, 8),
    "alpha": (8, 12),
    "beta": (12, 30),
    "gamma": (30, 50),
}
# Leading time samples shown in each band panel (previously a bare 200 in pixel units).
BAND_ZOOM_SAMPLES = 200

power_tfr = wavelets[0]  # first element of the tfr_morlet result
power = power_tfr.data  # indexed (channel, frequency, time) below
freqs = power_tfr.freqs
tfr_times = power_tfr.times


def _plot_power_heatmap(ax, power_subset, subset_freqs, times, title):
    """Draw channel-averaged power as a time-by-frequency heatmap on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    power_subset : array indexed (channel, frequency, time).
    subset_freqs : 1-D frequencies (Hz) matching axis 1 of ``power_subset``.
    times : 1-D times (s) matching axis 2 of ``power_subset``.
    title : Axes title.

    Returns
    -------
    The (frequency, time) array that was drawn, i.e. the mean over channels.
    """
    mean_power = np.mean(power_subset, axis=0)
    # Centre each row on its frequency (assumes evenly spaced freqs); the padding
    # also keeps the extent non-degenerate for a single-frequency band.
    half_step = (subset_freqs[1] - subset_freqs[0]) / 2 if len(subset_freqs) > 1 else 0.5
    image = ax.imshow(
        mean_power,
        aspect="auto",
        cmap="hot",
        origin="lower",
        # Label axes in seconds / Hz rather than raw array indices.
        extent=[
            times[0],
            times[-1],
            subset_freqs[0] - half_step,
            subset_freqs[-1] + half_step,
        ],
    )
    ax.figure.colorbar(image, ax=ax)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Frequency (Hz)")
    ax.set_title(title)
    return mean_power


fig, ax = plt.subplots()
avg_power = _plot_power_heatmap(ax, power, freqs, tfr_times, "Average Power")
plt.show()

# GROUPED PLOTS
n_rows = len(bands)
n_cols = 1

# squeeze=False keeps `axs` 2-D, so the loop also works with a single band.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 20), squeeze=False)
zoom_end = tfr_times[min(BAND_ZOOM_SAMPLES, tfr_times.size - 1)]

for ax, (band, (fmin, fmax)) in zip(axs[:, 0], bands.items()):
    band_indices = np.where((freqs >= fmin) & (freqs <= fmax))[0]
    if band_indices.size == 0:
        # imshow cannot draw an empty slice; leave a labelled blank panel instead.
        ax.set_title(f"Average Power ({band} band): no analysed frequencies")
        ax.set_axis_off()
        continue

    avg_power = _plot_power_heatmap(
        ax,
        power[:, band_indices, :],
        freqs[band_indices],
        tfr_times,
        f"Average Power ({band} band)",
    )
    ax.set_xlim(right=zoom_end)

plt.tight_layout()
plt.show()