In [32]:
from scipy.io import loadmat
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from dash import Dash, dcc, html, Input, Output, ctx, no_update
import numpy as np

from import_data_osf2 import get_data

In [34]:
# Read the raw MATLAB recording for session Macaque1_M1_181212.
macaque1_data = loadmat("../data/neuromosaics/osfstorage/nhp/Macaque1_M1_181212.mat")

# Hand the raw contents to the project loader; the result is used as a
# dict below (its entries become notebook-level variables).
d1 = get_data(
    mat=macaque1_data,
    name='Macaque1_M1_181212',
    is_mat=True,
    is_macaque=True,
)

# Expose every entry of d1 as a module-level variable so later cells can
# refer to the dataset fields by bare name.
globals().update(d1)

In [36]:
# Plotting parameters shared by the figures below.

# Index of the first sample whose stimulation-profile magnitude exceeds the
# (tiny) threshold; it is used as the zero point of the time axis.
where_zero = np.where(abs(stimProfile) > 10 ** -50)[0][0]

# Time axis with one entry per sample along axis 2 of evoked_emg, shifted so
# that sample `where_zero` sits at 0 and scaled by 1000 / fs
# (milliseconds, assuming fs is a scalar sampling rate in Hz -- TODO confirm).
time = np.arange(-int(where_zero), evoked_emg.shape[2] - int(where_zero)) / fs * 1000

# Font sizes (title, axes, legend).
fontsize_title = 15
fontsize_axes = 13
fontsize_legend = 10

# Independent copy of the electrode map.
maps1 = np.copy(maps)

In [38]:
# Helper that builds the heatmap matrix (the plotting itself happens in the Dash callbacks)

# Function to generate heatmap data
def generate_heatmap_data(iMuscle):
    """Build the electrode-grid matrix of mean responses for one muscle.

    Starts from an all-zero matrix shaped like the global ``maps`` and, for
    every row ``i`` of the global ``ch2xy``, stores
    ``sorted_respMean[i][iMuscle]`` at position
    ``(int(ch2xy[i][0]), int(ch2xy[i][1]))``. Positions that no channel maps
    to keep the value 0.

    Args:
        iMuscle: Index of the muscle (second index into ``sorted_respMean``).

    Returns:
        numpy.ndarray of floats with shape ``maps.shape``.
    """
    grid = np.zeros(maps.shape)
    for channel, coords in enumerate(ch2xy):
        grid[int(coords[0]), int(coords[1])] = sorted_respMean[channel][iMuscle]
    return grid

In [72]:
def update_detailed_figure_plotly(iArray, iMuscle, d):
    """Build the detailed 2x2 Plotly figure for one stimulation channel and muscle.

    Panels:
      * row 1, col 1 -- mean of ``sorted_filtered`` over the repetitions of
        the selected channel/muscle, with the ``resp_region`` window marked
        in orange;
      * row 1, col 2 -- histogram of ``sorted_resp`` for those repetitions;
      * row 2, col 1 -- heatmap of the ``sorted_evoked`` trials, with trials
        whose ``sorted_isvalid`` entry is 0 over-plotted in red.

    All signal values are multiplied by 1000 before plotting (presumably
    V -> mV, matching the 'MEP (mV)' axis label -- confirm against get_data).

    Side effect: every entry of ``d`` is written into the module globals, and
    the body then reads the dataset through those globals. The module-level
    ``time`` and ``where_zero`` values must already exist.

    Args:
        iArray: 0-based channel index; matched against ``stim_channel`` as
            ``iArray + 1`` (channel ids are 1-based).
        iMuscle: Index of the muscle to display.
        d: Dataset dictionary whose entries are published as globals.

    Returns:
        The assembled ``plotly`` figure.

    Raises:
        ValueError: If ``stim_channel`` contains no repetition for the
            selected channel.
    """
    # Publish the dataset so the bare names below resolve to this dataset.
    globals().update(d)

    n_repetitions = np.count_nonzero(stim_channel == iArray + 1)
    if n_repetitions == 0:
        # The min/max reductions below would fail on an empty slice anyway;
        # fail early with a message that names the channel.
        raise ValueError(f"No stimulation repetitions found for channel {iArray + 1}")

    # Common upper limit for the mean-EMG axis, taken across all channels.
    upLim = np.max(np.mean(sorted_evoked[:, iMuscle, :n_repetitions], 1)) * 1000

    # Largest repetition count over all channels. Channel ids are 1-based
    # (see n_repetitions above), so iterate 1..n_channels inclusive.
    upCounts = max(
        np.count_nonzero(stim_channel == channel_id)
        for channel_id in range(1, ch2xy.shape[0] + 1)
    )

    # Colour range of the trial heatmap spans both raw and filtered signals.
    minSorted = min(np.min(sorted_evoked[iArray, iMuscle, :n_repetitions, :]),
                    np.min(sorted_filtered[iArray, iMuscle, :n_repetitions, :])) * 1000
    maxSorted = max(np.max(sorted_evoked[iArray, iMuscle, :n_repetitions, :]),
                    np.max(sorted_filtered[iArray, iMuscle, :n_repetitions, :])) * 1000

    fig = make_subplots(rows=2, cols=2, subplot_titles=[
        f"{emgs['emgs'][iMuscle]}",
        "Distribution of peak amplitude",
        f"Stack of raw {emgs['emgs'][iMuscle]} EMG",
        ""
    ], horizontal_spacing=0.1, vertical_spacing=0.2)

    # --- Row 2, col 1: stack of raw trials -------------------------------
    # NOTE(review): no x values are passed, so this panel is positioned by
    # sample index although its axis is labelled 'Time (ms)' -- confirm.
    heatmap = go.Heatmap(z=sorted_evoked[iArray, iMuscle, :n_repetitions, :] * 1000,
                         colorscale='Blues', zmin=minSorted, zmax=maxSorted,
                         colorbar=dict(len=0.50, x=0.45, y=0.18, yanchor='middle'))
    fig.add_trace(heatmap, row=2, col=1)

    # One red marker per sample of every trial flagged as invalid, so each
    # such trial shows up as a red line across the heatmap.
    n_samples = sorted_evoked.shape[-1]
    invalid_rows = np.where(sorted_isvalid[iArray, iMuscle] == 0)[0]
    scatter_invalid = go.Scatter(x=np.tile(np.arange(n_samples), len(invalid_rows)),
                                 y=np.repeat(invalid_rows, n_samples),
                                 mode='markers', marker=dict(color='red', size=3),
                                 name='▬ Outliers', showlegend=True)
    fig.add_annotation(
        text="<b>▬ Outliers</b>",  # Use <b> for bold
        xref="paper",  # Use paper coordinates (0 to 1)
        yref="paper",
        x=0.6,
        y=0.35,
        showarrow=False,
        font=dict(
            family="Arial",
            size=14,
            color="red"
        )
    )
    fig.add_trace(scatter_invalid, row=2, col=1)

    # --- Row 1, col 1: mean filtered EMG and response window -------------
    mean_color = "#2171b5"  # A mid-tone from 'Blues' colorscale
    m = np.mean(sorted_filtered[iArray, iMuscle, :n_repetitions, :], axis=0) * 1000
    mean_curve = go.Scatter(x=time, y=m, mode='lines', name='Mean EMG', line=dict(color=mean_color))
    fig.add_trace(mean_curve, row=1, col=1)

    # Edges of the response window, converted with the same offset and
    # scaling as the module-level ``time`` axis.
    region_start = (resp_region[0] - where_zero) / fs * 1000
    region_end = (resp_region[1] - where_zero) / fs * 1000

    range_lines = go.Scatter(
        x=[region_start, region_end],
        y=[upLim, upLim],
        mode='lines',
        line=dict(color="orange", dash="dot"),
        name='Range to compute maximum peak'
    )
    fig.add_trace(range_lines, row=1, col=1)

    fig.add_vline(x=region_start, line=dict(color="orange", dash="dot"), row=1, col=1)
    fig.add_vline(x=region_end, line=dict(color="orange", dash="dot"), row=1, col=1)

    # --- Row 1, col 2: distribution of peak amplitudes -------------------
    hist = go.Histogram(x=sorted_resp[iArray, iMuscle, :n_repetitions] * 1000, nbinsx=10, name='Peak Amplitude')
    fig.add_trace(hist, row=1, col=2)

    # --- Axes and layout -------------------------------------------------
    fig.update_xaxes(title_text='Time (ms)', row=1, col=1)
    fig.update_yaxes(title_text='MEP (mV)', row=1, col=1, range=[0, upLim])
    fig.update_xaxes(title_text='Peak Amplitude', row=1, col=2)
    fig.update_yaxes(title_text='Counts', row=1, col=2, range=[0, upCounts])
    fig.update_xaxes(title_text='Time (ms)', row=2, col=1)
    fig.update_yaxes(title_text='Number of Trials', row=2, col=1)

    fig.update_layout(
        height=800,
        showlegend=True,
        legend=dict(x=0.3, y=1.1, orientation='h'),
        plot_bgcolor="white",
        paper_bgcolor="white"
    )

    return fig

In [100]:
#| label: figMcqcell

# One button per muscle; the ids 'btn-<index>' are what the callbacks listen to.
buttons = [
    html.Button(
        emgs['emgsabr'][muscle_idx],
        id=f'btn-{muscle_idx}',
        n_clicks=0,
        style={'padding': '10px', 'fontSize': '16px', 'marginBottom': '5px'},
    )
    for muscle_idx in range(n_muscles)
]

app = Dash(__name__)

# Right-hand column of panel A: a caption above the vertical stack of buttons.
muscle_panel_div = html.Div(
    [
        html.H4("Muscle selection", style={'textAlign': 'center', 'marginBottom': '10px'}),
        html.Div(buttons, style={'display': 'flex', 'flexDirection': 'column', 'marginLeft': '10px'}),
    ],
    style={'display': 'flex', 'flexDirection': 'column', 'marginLeft': '10px'},
)

# Panel A: the response heatmap with the muscle selector next to it.
heatmap_row_div = html.Div(
    [
        dcc.Graph(id='heatmap', style={'width': '70%', 'height': '600px'}),
        muscle_panel_div,
    ],
    style={'display': 'flex', 'alignItems': 'center'},
)

# Panel B: the detailed figure for the clicked heatmap cell, plus a hidden
# placeholder div.
detail_row_div = html.Div(
    [
        dcc.Graph(id='detailed-figure', style={'width': '100%', 'height': '500px'}),
        html.Div(id='dummy-output', style={'display': 'none'}),
    ]
)

# NOTE(review): heading B repeats heading A's text although panel B holds the
# detailed figure -- confirm the intended title.
app.layout = html.Div(
    [
        html.H2("A. Macaque motor response heatmap"),
        heatmap_row_div,
        html.H2("B. Macaque motor response heatmap"),
        detail_row_div,
    ]
)

@app.callback(
    Output('heatmap', 'figure'),
    [Input(f'btn-{i}', 'n_clicks') for i in range(n_muscles)],
    prevent_initial_call=True
)
def update_heatmap(*args):
    """Redraw the response heatmap for the muscle whose button was clicked.

    The muscle index is parsed from the id of the triggering component
    ('btn-<index>'); the click counts passed in ``args`` are not used.

    Args:
        *args: ``n_clicks`` of every muscle button (ignored).

    Returns:
        A ``go.Figure`` with the heatmap and grey markers on every cell whose
        value is exactly 0, or ``no_update`` when the callback was not
        triggered by a muscle button.
    """
    if not ctx.triggered:
        return no_update

    # Only string ids of the form 'btn-<index>' identify a muscle button.
    button_id = ctx.triggered_id
    if not isinstance(button_id, str) or not button_id.startswith('btn-'):
        return no_update

    index = int(button_id.split('-')[1])
    heatmap_data = generate_heatmap_data(index)
    n_rows, n_cols = heatmap_data.shape

    fig = go.Figure()

    fig.add_trace(go.Heatmap(
        z=heatmap_data,
        colorscale='Blues',
        # Nested title form: the flat `titleside` / `titlefont` attributes are
        # deprecated and rejected by recent Plotly releases.
        colorbar=dict(
            len=0.75,
            title=dict(text="Normalized MEP (mV)", side="right", font=dict(size=14)),
        ),
        zmin=np.min(heatmap_data),
        zmax=np.max(heatmap_data),
        hoverinfo='x+y+z'  # Enable hover information
    ))

    # Grey dots on every cell equal to 0 (cells no channel wrote to are 0;
    # they are labelled as ground electrodes).
    zero_coords = np.argwhere(heatmap_data == 0)
    if zero_coords.size > 0:
        fig.add_trace(go.Scatter(
            x=zero_coords[:, 1],
            y=zero_coords[:, 0],
            mode='markers',
            marker=dict(color='grey', size=8, symbol='circle'),
            name="Ground electrode",
            showlegend=True
        ))

    fig.update_layout(
        clickmode='event+select',  # Important for click events
        autosize=False,
        width=600, height=600,
        # x runs over the columns of the matrix, so its ticks follow shape[1].
        xaxis=dict(
            title="X coordinates",
            scaleanchor="y",
            showgrid=False,
            tickvals=np.arange(n_cols),
            ticktext=[str(i) for i in range(n_cols)]
        ),
        # y runs over the rows; reversed so row 0 is drawn at the top.
        yaxis=dict(
            title="Y coordinates",
            showgrid=False,
            autorange="reversed",
            range=[0, n_rows],
            tickvals=np.arange(n_rows),
            ticktext=[str(i) for i in range(n_rows)]
        ),
        margin=dict(l=50, r=50, t=50, b=50),
        plot_bgcolor="white",
        paper_bgcolor="white"
    )

    return fig

@app.callback(
    Output('detailed-figure', 'figure'),  # Update the detailed figure graph
    [Input('heatmap', 'clickData')],
    # Listen to the click *timestamps* of the muscle buttons: click counts
    # cannot tell which button was pressed most recently, timestamps can.
    [Input(f'btn-{i}', 'n_clicks_timestamp') for i in range(n_muscles)],
    prevent_initial_call=True
)
def handle_heatmap_click(clickData, *btn_clicks):
    """Show the detailed figure for the heatmap cell that was clicked.

    The channel is the row of ``ch2xy`` whose coordinates match the clicked
    cell; the muscle is the one whose button was pressed most recently.

    Args:
        clickData: ``clickData`` of the 'heatmap' graph (may be None).
        *btn_clicks: ``n_clicks_timestamp`` of every muscle button, in button
            order; a button that was never pressed reports a non-positive
            value (or None).

    Returns:
        The figure from ``update_detailed_figure_plotly`` for dataset ``d1``,
        or a placeholder figure when the trigger was not a heatmap click, no
        muscle has been selected yet, the clicked cell has no channel, or
        building the figure fails.
    """
    # Placeholder shown whenever no detailed figure can be produced.
    default_fig = go.Figure()
    default_fig.update_layout(title="Click on the heatmap to see details")

    # Button presses also trigger this callback; they reset the detail view.
    if ctx.triggered_id != 'heatmap' or not clickData:
        return default_fig

    try:
        point = clickData['points'][0]
        clicked_x = point['x']
        clicked_y = point['y']

        print(f"Heatmap clicked at X: {clicked_x}, Y: {clicked_y}")

        # Channel whose (row, column) entry in ch2xy matches the clicked cell.
        matches = np.flatnonzero((ch2xy[:, 1] == clicked_x) & (ch2xy[:, 0] == clicked_y))
        if matches.size == 0:
            print(f"No channel at X: {clicked_x}, Y: {clicked_y}")
            return default_fig
        iArray = matches[0]

        # Most recently pressed muscle button = largest click timestamp.
        stamps = [-1 if stamp is None else stamp for stamp in btn_clicks]
        if not stamps or max(stamps) <= 0:
            return default_fig
        muscle_index = max(range(len(stamps)), key=stamps.__getitem__)

        print(f"Updating detailed figure with iArray={iArray}, muscle={muscle_index}")
        return update_detailed_figure_plotly(iArray, muscle_index, d1)

    except Exception as e:
        # Best effort: a failed figure build must not break the app; report
        # the error on the server console and fall back to the placeholder.
        print(f"Error processing click: {e}")
        return default_fig

# Start the Dash server (debug mode, port 8051) when this code is executed
# as the main module.
if __name__ == '__main__':
    app.run(debug=True, port=8051)

In [102]:
# Sanity check: build the heatmap matrix for muscle index 0 and display its
# shape as the cell output.
hd = generate_heatmap_data(0)
hd.shape

(10, 10)