In [None]:
# Look Ma, a Spike


In [None]:
import math

import matplotlib
import mne
import numpy as np

from MEG_Tools import MEG

# Select the Qt backend through the top-level matplotlib package *before*
# pyplot is imported, so `plt` always means pyplot (its conventional meaning)
# instead of being bound to the matplotlib package itself.
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)

# Open the recording with the project's MEG helper and fetch the MNE object
# that the cells below read `info['sfreq']` / `n_times` from and epoch.
meg = MEG('case_2225_with_spike_dipoles_sleep_2.mat')
mne_fif = meg.get_mne()

In [None]:
#First we find the location of a known spike


In [None]:
# 'begin' of the first entry in meg.spikes, rounded down to the whole second
# below it (the bare name on the last line displays the value).
first_spike_begin = meg.spikes[0]['begin']
known_spike_time = int(math.floor(first_spike_begin))
known_spike_time

In [None]:
def round_down_to_nearest_epoch(time, epoch_length):
    """Return how many whole epoch lengths fit into ``time``.

    The quotient ``time / epoch_length`` is floored and returned as an ``int``.
    The result is therefore a count of epochs (the notebook uses it as a
    zero-based epoch index), not a time rounded to a multiple of
    ``epoch_length``.
    """
    epochs_elapsed = time / epoch_length
    return int(math.floor(epochs_elapsed))


In [None]:
# Make epochs: cut the continuous recording into back-to-back, fixed-length
# windows of `length_of_epoch` seconds.
length_of_epoch = 1  # epoch duration in seconds
sfreq = mne_fif.info['sfreq']
samples_per_epoch = int(length_of_epoch * sfreq)

# One event row [sample, 0, 1] at the start of every epoch. The range stops one
# epoch length before the end of the data so the last window still fits.
event_onsets = np.arange(0, mne_fif.n_times - samples_per_epoch, samples_per_epoch)
events = np.column_stack(
    [event_onsets, np.zeros_like(event_onsets), np.ones_like(event_onsets)]
).astype(int)

# The window length must match the event spacing. A fixed tmax=3 with events
# every `length_of_epoch` seconds made neighbouring epochs overlap (so a spike
# also showed up in the preceding "unspiked" epochs) and pushed the last epochs
# past the end of the recording. Stopping one sample short of the next event
# keeps consecutive epochs from sharing their boundary sample.
epochs = mne.Epochs(
    mne_fif,
    events,
    event_id={str(length_of_epoch) + " sec": 1},
    tmin=0,
    tmax=length_of_epoch - 1.0 / sfreq,
    baseline=None,
    preload=True,
)


In [None]:
# Let's take a look at the epoch whose index is derived from the known spike time.
known_spike_epoch = round_down_to_nearest_epoch(known_spike_time, length_of_epoch)
first_spike = epochs[known_spike_epoch]

# Plot that epoch (the spike epoch, not necessarily epoch 0)
first_spike.plot()

In [None]:
# Plot - Morlet wavelet time-frequency power around the known spike
frequencies = np.arange(4, 70, 1)
# NOTE: despite its name, spike_time is an epoch index, not a time. The name is
# kept because a later cell reads it.
spike_time = round_down_to_nearest_epoch(known_spike_time, length_of_epoch)
# Use the spike epoch plus the one before it. Clamp the start at 0: when the
# spike lies in the very first epoch, a start of -1 would be read as "last
# epoch" by Python slicing and the selection would come back empty.
spike_window_start = max(spike_time - 1, 0)
power = mne.time_frequency.tfr_morlet(
    epochs[spike_window_start:spike_time + 1],
    n_cycles=2,
    return_itc=False,
    freqs=frequencies,
    decim=3,
)
power.plot(["MEG0242"])

In [None]:
# Plot - Morlet wavelet time-frequency power for an epoch without a spike
frequencies = np.arange(4, 70, 1)

# Epoch indices in which an entry of meg.spikes begins.
spike_epoch_indices = {
    round_down_to_nearest_epoch(spike['begin'], length_of_epoch) for spike in meg.spikes
}
# The comparison epoch used to be hard-coded (and scaled by the epoch length in
# seconds, although it indexes epochs). Start from the same epoch 3, but move
# on until an epoch is found in which no spike begins.
non_spike_epoch = next(
    (index for index in range(3, len(epochs)) if index not in spike_epoch_indices),
    None,
)
if non_spike_epoch is None:
    raise ValueError("No spike-free epoch found at or after epoch 3")

power = mne.time_frequency.tfr_morlet(
    epochs[non_spike_epoch:non_spike_epoch + 1],
    n_cycles=2,
    return_itc=False,
    freqs=frequencies,
    decim=3,
)
power.plot(["MEG0242"])

In [None]:
# Plot - image of channel MEG0242 over the spike epoch and the one before it.
# spike_time is the spike's epoch index (set in the wavelet-spike cell). Clamp
# the start at 0 so a spike in the first epoch does not turn the slice into
# epochs[-1:1], which selects nothing.
spike_window_start = max(spike_time - 1, 0)
epochs[spike_window_start:spike_time + 1].plot_image(picks=["MEG0242"])

In [None]:
def get_spiked_lists(epochs_, meg_, epoch_length=None):
    """Split epochs into those in which a spike begins and those in which none does.

    Parameters
    ----------
    epochs_ : sequence-like
        Epoch container. Only ``len(epochs_)`` and indexing with a list of
        integer positions are used (e.g. an ``mne.Epochs`` object).
    meg_ : object
        Must expose ``spikes``, an iterable of mappings with a ``'begin'``
        entry. ``'begin'`` is assumed to be expressed in the same unit as the
        epoch length (seconds in this notebook) -- TODO confirm against
        MEG_Tools.
    epoch_length : float, optional
        Length of one epoch. Defaults to the notebook-level ``length_of_epoch``
        so existing two-argument calls keep working.

    Returns
    -------
    tuple
        ``(spiked_list, unspiked_list, spiked_epochs, unspiked_epochs)`` where
        the two lists hold sorted epoch indices and the last two items are
        ``epochs_`` indexed with those lists.
    """
    if epoch_length is None:
        epoch_length = length_of_epoch

    # len(epochs_) already counts epochs; it must not be divided by the epoch
    # length again (that silently discarded epochs whenever the length != 1).
    n_epochs = len(epochs_)

    spiked_indices = set()
    for spike in meg_.spikes:
        # Same computation as round_down_to_nearest_epoch(begin, epoch_length).
        index = int(math.floor(spike['begin'] / epoch_length))
        # Ignore spikes that fall outside the available epochs; indexing
        # epochs_ with them would fail.
        if 0 <= index < n_epochs:
            spiked_indices.add(index)

    spiked_list_ = sorted(spiked_indices)
    # Set membership keeps this linear in the number of epochs.
    unspiked_list_ = [i for i in range(n_epochs) if i not in spiked_indices]

    spiked_epochs_ = epochs_[spiked_list_]
    unspiked_epochs_ = epochs_[unspiked_list_]
    return (spiked_list_, unspiked_list_, spiked_epochs_, unspiked_epochs_)



In [None]:
# Partition the epochs by whether a spike begins in them.
spike_partition = get_spiked_lists(epochs, meg)
spiked_list, unspiked_list, spiked_epochs, unspiked_epochs = spike_partition

In [None]:
# Plot - average of the spiked epochs against the average of the unspiked ones
spiked_evoked = spiked_epochs.average()
unspiked_evoked = unspiked_epochs.average()

evokeds_by_condition = {"spiked": spiked_evoked, "unspiked": unspiked_evoked}
mne.viz.plot_compare_evokeds(
    evokeds_by_condition,
    legend="upper left",
    show_sensors="upper right",
)

In [None]:
# Plot - spiked minus unspiked average (weights +1 / -1), "mag" channels only
difference_weights = [1, -1]
evoked_diff = mne.combine_evoked(
    [spiked_evoked, unspiked_evoked],
    weights=difference_weights,
)
evoked_diff.pick(picks="mag").plot_topo(color="r", legend=False)

In [None]:
# Plot - Hmmmm
# Joint plot of the spiked average, then topographic maps at selected times.
topomap_times = [0.0, 0.08, 0.1, 0.12, 0.2]
spiked_evoked.plot_joint(picks="mag")
spiked_evoked.plot_topomap(times=topomap_times, ch_type="mag")

## Clustering

In [None]:
import functools
import time
import pywt
from scipy.stats import skew, kurtosis  # Importing skew and kurtosis

def timer(func):
    """Decorator that prints the wall-clock duration of every call to ``func``.

    The wrapped function's return value is handed back unchanged, and
    ``functools.wraps`` carries its name and docstring over to the wrapper.
    Nothing is printed if ``func`` raises.
    """
    @functools.wraps(func)
    def wrapper_timer(*args, **kwargs):
        started = time.perf_counter()
        result = func(*args, **kwargs)
        finished = time.perf_counter()
        print(f"Elapsed time: {finished - started:0.4f} seconds")
        return result
    return wrapper_timer

def get_labels(spiked_list_, unspiked_list_):
    """Build one 0/1 label per epoch index: 1 for spiked epochs, 0 otherwise.

    Parameters
    ----------
    spiked_list_ : iterable of int
        Indices of the epochs that contain a spike.
    unspiked_list_ : iterable of int
        Indices of the remaining epochs. Only used to determine how many
        labels are needed.

    Returns
    -------
    list of int
        Labels for indices ``0 .. max(index)``. Empty when both inputs are
        empty.
    """
    # Materialise both inputs as lists so `+` always concatenates. With numpy
    # arrays it would add element-wise and produce a wrong epoch count.
    spiked_indices = list(spiked_list_)
    all_indices = spiked_indices + list(unspiked_list_)
    if not all_indices:
        # max() of an empty sequence raises; no epochs simply means no labels.
        return []

    # Indices are zero-based, so the highest index + 1 is the epoch count.
    total_epochs = int(max(all_indices)) + 1

    # Everything starts out unspiked (0); spiked epochs are then flipped to 1.
    labels_ = [0] * total_epochs
    for index in spiked_indices:
        labels_[index] = 1
    return labels_

def _coefficient_statistics(coeff):
    """Summarise one wavelet coefficient array as five scalar statistics."""
    return [
        np.sum(coeff ** 2),  # energy
        np.mean(coeff),
        np.var(coeff),
        # axis=None reduces over the whole array, like the three statistics
        # above. The scipy default (axis=0) returns one value per column.
        skew(coeff, axis=None),
        kurtosis(coeff, axis=None),
    ]


@timer
def get_features(epochs_):
    """Compute wavelet-based summary features for every epoch.

    Each epoch's data is decomposed with a 4-level 'db4' discrete wavelet
    transform. For every returned coefficient array five scalars are
    collected: energy, mean, variance, skewness and kurtosis.

    Parameters
    ----------
    epochs_ : mne.Epochs-like
        Must provide ``get_data(copy=False)``; the result is iterated epoch by
        epoch (assumed shape ``(n_epochs, n_channels, n_times)`` -- TODO
        confirm for non-MNE inputs).

    Returns
    -------
    list of list
        One feature row per epoch, 5 values per wavelet coefficient array.
    """
    # Fetch the data once instead of building a one-epoch object per iteration.
    all_data = epochs_.get_data(copy=False)
    n_epochs = len(all_data)

    features_ = []
    for n, data in enumerate(all_data):
        # Progress indicator, rewritten in place on a single line
        perc = np.round(100 * n / n_epochs, 1)
        print(f"Progress: {perc}%", end='\r', flush=True)

        # Discrete wavelet transform (applied along the last axis by default)
        coeffs = pywt.wavedec(data, wavelet='db4', level=4)

        epoch_features = []
        for coeff in coeffs:
            epoch_features.extend(_coefficient_statistics(coeff))
        features_.append(epoch_features)

    # End the in-place progress line once the loop is done. This used to be a
    # module-level print that ran when the function was defined.
    print('')
    return features_


In [None]:

# The two results are independent of each other: 0/1 spike labels built from
# the index lists, and one feature row per epoch.
labels = get_labels(spiked_list, unspiked_list)
features = get_features(epochs)


In [None]:
def flatten_mixed_list(mixed_list):
    """Recursively flatten nested lists / numpy arrays into one flat list.

    Only ``list`` and ``np.ndarray`` items are descended into. Anything else
    (scalars, tuples, strings, ...) is kept as a single element, in order.
    """
    flat = []
    for element in mixed_list:
        if not isinstance(element, (list, np.ndarray)):
            # Leaf value: keep as is
            flat.append(element)
            continue
        # Nested container: splice in its flattened contents
        flat += flatten_mixed_list(element)
    return flat


# Replace every per-epoch feature row with its flattened version, in place.
for n, epoch_row in enumerate(features):
    features[n] = flatten_mixed_list(epoch_row)

In [None]:
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

spiked_labels = labels.copy()
n_clusters = 6

# The per-cluster spike counts below pair labels and feature rows by position,
# so both must describe the same epochs. Fail early with a clear message rather
# than with an IndexError (or a silently truncated count) later on.
if len(spiked_labels) != len(features):
    raise ValueError(
        f"Got {len(spiked_labels)} labels for {len(features)} feature rows; "
        "labels and features must cover the same epochs"
    )

# Apply K-means clustering (fixed random_state for reproducible assignments)
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto').fit(features)
cluster_labels = kmeans.labels_

# Reduce the feature dimensions to 2D using PCA (for visualisation only;
# the clustering itself ran on the full feature vectors)
pca = PCA(n_components=2)
reduced_features = pca.fit_transform(features)

# Scatter plot of the n_clusters clusters, one colour per cluster
fig, ax = plt.subplots(figsize=(10, 6))
for cluster in range(n_clusters):
    mask = (cluster_labels == cluster)
    ax.scatter(reduced_features[mask, 0], reduced_features[mask, 1], label=f'Cluster {cluster}')

# Cluster centroids projected into the same 2D PCA space
centroids = pca.transform(kmeans.cluster_centers_)
ax.scatter(centroids[:, 0], centroids[:, 1], marker='X', s=200, color='k', label='Centroids')

ax.set_xlabel('PCA Feature 1')
ax.set_ylabel('PCA Feature 2')
ax.set_title(f'2D Visualization of {n_clusters} Clusters')
# The legend stays off, as before; call ax.legend() to show the cluster labels.
plt.show()

# Count the spiked epochs (label 1) that landed in each cluster
spiked_label_array = np.asarray(spiked_labels)
for cluster in range(n_clusters):
    this_cluster_spiked_count = int(spiked_label_array[cluster_labels == cluster].sum())
    print(f"Cluster {cluster} contains {this_cluster_spiked_count} spiked instances")

