In [1]:
import matplotlib.pyplot as plt
import numpy as np
from mne import set_log_level
from mne.io import read_raw_eeglab

from pycrostates.cluster import ModKMeans
from pycrostates.datasets import lemon
from pycrostates.preprocessing import extract_gfp_peaks



In [None]:
# STEP 1 - For each candidate number of clusters K, fit ModKMeans on the GFP
# peaks, record its GEV, and save a picture of the cluster centers so it can
# be shown when hovering over the corresponding point in the plot below.

set_log_level("ERROR")  # reduce verbosity

SUBJECT_ID = "010017"
CONDITION = "EO"
CROP_TMAX = 180  # keep the first 180 s of the recording (Raw.crop works in seconds)
RANDOM_STATE = 42  # fixed seed so every K-means fit is reproducible
N_JOBS = 6  # parallel workers passed to ModKMeans.fit

raw_fname = lemon.data_path(subject_id=SUBJECT_ID, condition=CONDITION)
raw = read_raw_eeglab(raw_fname, preload=True)
raw.crop(0, CROP_TMAX)
raw.pick("eeg")
raw.set_eeg_reference("average")

cluster_numbers = [2, 3, 4, 5, 6, 8, 10, 12, 15, 20]

gfp_peaks = extract_gfp_peaks(raw)

# gevs[i] and image_files[i] both belong to cluster_numbers[i].
gevs = []
image_files = []
for n_clusters in cluster_numbers:
    # Fit K-means with a set number of cluster centers.
    modk = ModKMeans(n_clusters=n_clusters, random_state=RANDOM_STATE)
    modk.fit(gfp_peaks, n_jobs=N_JOBS, verbose="WARNING")
    gevs.append(modk.GEV_)

    # Name the file after K (not the loop position) so it is self-describing.
    fig = modk.plot(show=False)
    image = f"cluster_centers_{n_clusters}.png"
    fig.savefig(image)
    plt.close(fig)  # avoid accumulating open figures across iterations
    image_files.append(image)



In [None]:
import base64
from bokeh.io import output_file, push_notebook, show
from bokeh.models import HoverTool
from bokeh.plotting import figure, ColumnDataSource
from bokeh.io import output_notebook
from bokeh.layouts import row
from bokeh.resources import INLINE

# Build an interactive scatter plot of GEV versus number of clusters K.
# Hovering over a point shows the cluster-center image saved for that K.

# Prepare the data: one point per fitted model.
x = list(cluster_numbers)
y = gevs

# Load the images and embed them as base64 data URIs so the tooltip can
# display them without serving the PNG files separately.
images = []
for file in image_files:
    with open(file, "rb") as f:
        encoded_image = base64.b64encode(f.read()).decode("utf-8")
    images.append(f"data:image/png;base64,{encoded_image}")

# All ColumnDataSource columns must have the same length.
assert len(x) == len(y) == len(images), "x, y and images must be aligned"

# Custom HTML tooltip: '@field' is replaced by that column's value for the
# hovered point. The alt text uses K rather than the (huge) data URI.
hover = HoverTool(
    tooltips="""
    <div>
        <div>K = @x, GEV = @y{0.000}</div>
        <div>
            <img
                src="@imgs" alt="Cluster centers for K=@x"
                style="float: left; margin: 0px 15px 15px 0px;"
                border="2"
            ></img>
        </div>
    </div>
    """
)

# Create a Bokeh ColumnDataSource
source = ColumnDataSource(data=dict(x=x, y=y, imgs=images))

# Create the scatter plot. Pan/zoom tools are kept alongside the hover tool
# because passing only the hover tool would make the plot non-zoomable.
p = figure(
    title="Global explained variance per number of clusters",
    x_axis_label="Number of clusters (K)",
    y_axis_label="GEV",
    tools=[hover, "pan", "wheel_zoom", "box_zoom", "reset", "save"],
)
p.scatter(x="x", y="y", size=10, color="red", source=source);


In [None]:
# Configure Bokeh so later show() calls also write the plot to a standalone
# HTML file. Bokeh allows file and notebook output to be active together.
output_file(filename="scatter_plot.html")

In [None]:
# Render the plot inline in the notebook. INLINE embeds BokehJS in the output
# rather than loading it from a CDN. If output_file() was called earlier,
# show() will additionally write (and try to open) the HTML file.
output_notebook(resources=INLINE)
handle = show(p, notebook_handle=True)

# Pan/zoom/hover run in the browser and need no Python round-trip. The handle
# is kept only so that later Python-side edits to `p` or `source` can be sent
# to this output with push_notebook(handle=handle).