In [1]:
import matplotlib.pyplot as plt
import numpy as np
from mne import set_log_level
from mne.io import read_raw_eeglab

from pycrostates.cluster import ModKMeans
from pycrostates.datasets import lemon


# Only report errors from MNE; informational messages would clutter the notebook.
set_log_level("ERROR")

# Fetch one LEMON subject's "EO" recording (presumably eyes-open) and load it into memory.
raw_fname = lemon.data_path(subject_id="010017", condition="EO")
raw = read_raw_eeglab(raw_fname, preload=True)

# Keep the first 180 s, restrict to EEG channels and re-reference to the common
# average. Each method modifies `raw` in place and returns it, so the chain's
# final value (shown as the cell output) is `raw` itself.
raw.crop(tmin=0, tmax=180).pick("eeg").set_eeg_reference("average")

0,1
Measurement date,Unknown
Experimenter,Unknown
Digitized points,64 points
Good channels,61 EEG
Bad channels,
EOG channels,Not available
ECG channels,Not available
Sampling frequency,250.00 Hz
Highpass,0.00 Hz
Lowpass,125.00 Hz


In [23]:
from pycrostates.preprocessing import extract_gfp_peaks


# Candidate numbers of microstate classes to compare.
cluster_numbers = [2, 3, 4, 5, 6]

# Every model below is fitted on this same set of GFP-peak samples from `raw`.
gfp_peaks = extract_gfp_peaks(raw)

gevs = []         # `GEV_` of each fitted model, in `cluster_numbers` order
image_files = []  # paths of the saved cluster-center figures, in the same order
labels = []       # per-sample labels of each fitted model, in the same order
for k, n_clusters in enumerate(cluster_numbers):
    # Fixed random_state so re-running the notebook yields the same clustering.
    ModK = ModKMeans(n_clusters=n_clusters, random_state=42)
    ModK.fit(gfp_peaks, n_jobs=6, verbose="WARNING")
    gevs.append(ModK.GEV_)
    # Public accessor instead of the private `_labels_` attribute.
    labels.append(ModK.labels_)

    # Draw the cluster centers off-screen, save them to disk and close the figure
    # so the loop does not accumulate open matplotlib figures.
    fig = ModK.plot(show=False)
    image = f'cluster_centers_{k}.png'
    fig.savefig(image)
    plt.close(fig)
    image_files.append(image)

# NOTE: `ModK` deliberately stays bound to the last fitted model
# (n_clusters == cluster_numbers[-1]); the UMAP cell reads `ModK.fitted_data`.

In [33]:
from bokeh.io import show
from bokeh.layouts import column
from bokeh.models import ColorBar, ColumnDataSource, HoverTool, Select, CustomJS
from bokeh.palettes import Viridis256
from bokeh.plotting import figure
from bokeh.transform import linear_cmap
import umap

from bokeh.io import output_notebook

# Render Bokeh output inline in the notebook instead of in a separate browser tab.
output_notebook()

# 2-D UMAP embedding of the samples the clustering was fitted on. `ModK` is the
# last model from the clustering loop; all models were fitted on the same GFP
# peaks, so presumably any of them exposes the same `fitted_data`.
# `.T` assumes fitted_data is (channels, samples) so UMAP sees samples as rows.
# random_state makes the embedding reproducible, matching the clustering cell.
mapper = umap.UMAP(random_state=42).fit(ModK.fitted_data.T)

# Embedding coordinates, plus one label column per K added below.
data = {
    'x': mapper.embedding_[:, 0],
    'y': mapper.embedding_[:, 1],
}

color_columns = []
label_highs = []  # top of the color range for each column, parallel to color_columns
for n_clusters, label in zip(cluster_numbers, labels):
    assert len(label) == len(data['x']), (
        f"k={n_clusters}: {len(label)} labels vs {len(data['x'])} embedded samples"
    )
    column_name = f'k_{n_clusters}'
    data[column_name] = label
    color_columns.append(column_name)
    # Labels are expected in 0..n_clusters-1; keep the range non-degenerate.
    label_highs.append(max(n_clusters - 1, 1))

source = ColumnDataSource(data=data)

# Dropdown that selects which K's labels color the points.
select = Select(title='Select K Column', options=color_columns, value=color_columns[0])

# Color by the initially selected column; the range is sized to that K so the
# full palette is used. The JS callback updates field and range together.
color_mapper = linear_cmap(field_name=select.value, palette=Viridis256, low=0, high=label_highs[0])

plot = figure(
    title='UMAP embedding of GFP peaks colored by microstate label',
    x_axis_label='UMAP 1',
    y_axis_label='UMAP 2',
)
# `scatter` (circle marker by default) instead of `circle(..., size=...)`,
# which is deprecated since Bokeh 3.4.
scatter = plot.scatter('x', 'y', source=source, fill_color=color_mapper, line_color='black', size=10)

# The color bar shares the mapper model, so range changes in JS show up here too.
color_bar = ColorBar(color_mapper=color_mapper['transform'], label_standoff=12)
plot.add_layout(color_bar, 'right')

hover_tool = HoverTool(tooltips=[('x', '@x'), ('y', '@y'), (select.value, f'@{select.value}')])
plot.add_tools(hover_tool)

# Browser-side callback: re-point color, color range and tooltip at the chosen column.
# The mapper model is passed directly, and the ranges as a list parallel to
# `select.options`, so both arrive in JS as plain models/arrays.
callback = CustomJS(
    args=dict(
        cmap=color_mapper['transform'],
        highs=label_highs,
        hover_tool=hover_tool,
        scatter=scatter,
        select=select,
    ),
    code="""
    const value = select.value;

    // Rescale the shared mapper (and hence the color bar) to the selected K.
    cmap.high = highs[select.options.indexOf(value)];

    // Color the glyph by the selected label column.
    scatter.glyph.fill_color = {field: value, transform: cmap};

    // Keep the hover tooltip in sync with the displayed column.
    hover_tool.tooltips = [["x", "@x"], ["y", "@y"], [value, "@" + value]];
""",
)
select.js_on_change('value', callback)

# Dropdown above the plot.
layout = column(select, plot)
show(layout)