In [1]:
import pickle

import numpy as np
import plotly.graph_objects as go
from dash import Dash, Input, Output, State, ctx, dcc, html, no_update
from plotly.subplots import make_subplots
from scipy.io import loadmat

from import_data_osf2 import get_data

In [2]:
is_mat = False

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load data files from a trusted source.
with open("../data/neuromosaics/osfstorage/rat1/rat1_C5_500uA.pkl", "rb") as pkl_file:
    data = pickle.load(pkl_file)
d1 = get_data(is_mat=is_mat, is_macaque=False, mat=data, name='rat1_C5_500uA')

# Promote every entry of the data dictionary to a notebook-level variable;
# later cells use names such as maps, ch2xy, fs and stimProfile directly.
globals().update(d1)

# --- Plotting parameters ---
# where_zero: index of the first sample whose |stimProfile| exceeds 1e-50,
# i.e. the first effectively non-zero stimulation sample; used as t = 0.
where_zero = np.where(abs(stimProfile) > 10**(-50))[0][0]
zero_offset = int(where_zero)
# Time axis (sample / fs, scaled by 1000 -> ms when fs is in Hz), relative to where_zero.
time = 1000 * np.array([sample / fs for sample in range(-zero_offset, evoked_emg.shape[2] - zero_offset)])
fontsize_title, fontsize_axes, fontsize_legend = 15, 13, 10
maps1 = np.copy(maps)  # untouched copy; a later cell derives rotated maps from it

In [3]:
# Function to generate heatmap data
def generate_heatmap_data(iMuscle):
    data_matrix = np.zeros(maps.shape)
    for i in range(ch2xy.shape[0]) :
        data_matrix[int(ch2xy[i][0])][int(ch2xy[i][1])] = sorted_respMean[i][iMuscle]
    return data_matrix  # Generate random heatmap data

In [4]:
def update_detailed_figure_plotly(iArray, iMuscle, d):
    """Build the 2x2 detail figure for one stimulation site and one muscle.

    Panels:
      (1,1) raw evoked EMG, one heatmap row per trial, invalid trials marked red;
      (1,2) filtered EMG, one heatmap row per trial;
      (2,1) mean filtered EMG with the window used to compute the peak;
      (2,2) histogram of per-trial peak amplitudes.

    Parameters
    ----------
    iArray : int
        Zero-based stimulation-site index; its trials are those where
        ``stim_channel == iArray + 1``.
    iMuscle : int
        Muscle index into the sorted arrays and ``emgs['emgs']``.
    d : dict
        Data dictionary (as produced by ``get_data``). Entries of ``d`` take
        precedence; names missing from ``d`` (e.g. ``time`` and ``where_zero``,
        computed in an earlier cell) are read from the notebook globals.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    def lookup(name):
        # Read-only: nothing is written back to globals(), so notebook state
        # such as the rotated ``maps`` used by the heatmap stays untouched.
        return d[name] if name in d else globals()[name]

    stim_channel = lookup('stim_channel')
    n_channels = lookup('ch2xy').shape[0]
    emgs = lookup('emgs')
    sorted_evoked = lookup('sorted_evoked')
    sorted_filtered = lookup('sorted_filtered')
    sorted_isvalid = lookup('sorted_isvalid')
    sorted_resp = lookup('sorted_resp')
    resp_region = lookup('resp_region')
    fs = lookup('fs')
    where_zero = lookup('where_zero')
    time = lookup('time')

    def to_ms(sample):
        # Sample index -> time (x1000, ms if fs is in Hz) relative to where_zero.
        return (sample - where_zero) / fs * 1000

    muscle_name = emgs['emgs'][iMuscle]

    # Stimulation channels are numbered from 1, hence the "+ 1" offsets.
    n_repetitions = np.where(stim_channel == iArray + 1)[0].shape[0]
    # Values are scaled by 1000 (presumably V -> mV, matching the axis labels).
    upLim = np.max(np.mean(sorted_evoked[:, iMuscle, :n_repetitions], 1)) * 1000
    # Largest trial count over all channels: a shared y-range for the histogram.
    upCounts = max(np.where(stim_channel == ch + 1)[0].shape[0] for ch in range(n_channels))

    raw_trials = sorted_evoked[iArray, iMuscle, :n_repetitions, :] * 1000
    filtered_trials = sorted_filtered[iArray, iMuscle, :n_repetitions, :] * 1000
    # One colour range for both stacks so raw and filtered are comparable.
    minSorted = min(np.min(raw_trials), np.min(filtered_trials))
    maxSorted = max(np.max(raw_trials), np.max(filtered_trials))

    fig = make_subplots(rows=2, cols=2, subplot_titles=[
        f"Stack of raw {muscle_name} EMG",
        f"Stack of filtered {muscle_name} EMG",
        f"{muscle_name}",
        "Distribution of peak amplitude",
    ], horizontal_spacing=0.1, vertical_spacing=0.2)

    fig.add_trace(go.Heatmap(z=raw_trials, colorscale='Blues', zmin=minSorted, zmax=maxSorted,
                             colorbar=dict(len=0.4, x=1, y=0.8, yanchor='middle')),
                  row=1, col=1)
    # Same colorscale and range as the raw stack; its own colorbar would sit
    # exactly on top of the first one, so it is hidden.
    fig.add_trace(go.Heatmap(z=filtered_trials, colorscale='Blues', zmin=minSorted, zmax=maxSorted,
                             showscale=False),
                  row=1, col=2)

    # Red dots across every sample of each trial flagged invalid
    # (sorted_isvalid == 0), restricted to the rows shown in the heatmap.
    n_samples = sorted_evoked.shape[-1]
    invalid_rows = np.where(sorted_isvalid[iArray, iMuscle][:n_repetitions] == 0)[0]
    fig.add_trace(go.Scatter(x=np.tile(np.arange(n_samples), len(invalid_rows)),
                             y=np.repeat(invalid_rows, n_samples),
                             mode='markers', marker=dict(color='red', size=3),
                             name='Outliers', showlegend=True),
                  row=1, col=1)

    mean_color = "#2171b5"  # a mid-tone from the 'Blues' colorscale
    m = np.mean(sorted_filtered[iArray, iMuscle, :n_repetitions, :], axis=0) * 1000
    fig.add_trace(go.Scatter(x=time, y=m, mode='lines', name='Mean EMG', line=dict(color=mean_color)),
                  row=2, col=1)

    # Window (resp_region) in which the maximum peak is computed.
    window_ms = [to_ms(resp_region[0]), to_ms(resp_region[1])]
    fig.add_trace(go.Scatter(x=window_ms, y=[upLim, upLim], mode='lines',
                             line=dict(color="orange", dash="dot"),
                             name='Range to compute maximum peak'),
                  row=2, col=1)
    for edge in window_ms:
        fig.add_vline(x=edge, line=dict(color="orange", dash="dot"), row=2, col=1)

    fig.add_trace(go.Histogram(x=sorted_resp[iArray, iMuscle, :n_repetitions] * 1000, nbinsx=10,
                               name='Peak Amplitude'),
                  row=2, col=2)

    fig.update_xaxes(title_text='Time (ms)', row=2, col=1)
    fig.update_yaxes(title_text='MEP (mV)', row=2, col=1, range=[0, upLim])
    fig.update_xaxes(title_text='Peak Amplitude', row=2, col=2)
    fig.update_yaxes(title_text='Counts', row=2, col=2, range=[0, upCounts])
    fig.update_xaxes(title_text='Time (ms)', row=1, col=1)
    fig.update_yaxes(title_text='Number of Trials', row=1, col=1)
    fig.update_xaxes(title_text='Time (ms)', row=1, col=2)
    fig.update_yaxes(title_text='Number of Trials', row=1, col=2)

    fig.update_layout(
        height=800,
        showlegend=True,
        legend=dict(x=0, y=0.5, orientation='h'),
        plot_bgcolor="white",
        paper_bgcolor="white"
    )

    return fig

In [9]:
# Orient the grid according to additional_info['Position_Line_0']
# ('Left' -> np.rot90 k=1, 'Right' -> k=3), then recompute each channel's
# grid coordinates in ch2xy from the (possibly rotated) map.
heatmap_data = generate_heatmap_data(0)
if not is_mat:
    position = additional_info['Position_Line_0'].item()
    if position == 'Left':
        heatmap_data = np.rot90(heatmap_data, 1)
        maps = np.rot90(maps1, 1)
    elif position == 'Right':
        heatmap_data = np.rot90(heatmap_data, 3)
        maps = np.rot90(maps1, 3)

    for ch in range(int(parameters['nChan'])):
        # First grid cell holding channel ch + 1 (channel ids in maps start at 1).
        x, y = np.argwhere(maps == ch + 1)[0]
        ch2xy[ch] = [x, y]


In [11]:
def plot_boussole(fig):
    """Draw an orientation compass ("boussole") near the top-left of ``fig``.

    Adds, in paper coordinates, an arrow pointing left labelled "Left" and an
    arrow pointing up labelled "Rostral", and sets 50 px margins on all sides
    so the annotations are not hidden. ``fig`` is modified in place.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Figure to annotate.

    Returns
    -------
    The same ``fig`` object, for chaining.
    """
    # Reference point in paper coordinates (fractions of the plotting area;
    # values outside [0, 1] land in the margins).
    compass_x, compass_y = 0.1, 0.95
    arrow_length = 50  # in pixels

    arrow_style = dict(
        xref="paper", yref="paper",
        axref="pixel", ayref="pixel",  # ax/ay: pixel offset of the arrow tail from its head
        showarrow=True,
        arrowhead=2,     # arrowhead style (1-5)
        arrowsize=1.5,
        arrowwidth=2,
        arrowcolor="black",
    )
    label_style = dict(
        xref="paper", yref="paper",
        showarrow=False,
        font=dict(color="black", size=12),
        xanchor="center",
        yanchor="middle",
    )

    left_head_x = compass_x - 0.11
    rostral_head_y = compass_y + 0.08

    # Left arrow: a positive pixel ax places the tail right of the head -> points left.
    fig.add_annotation(x=left_head_x, y=compass_y, ax=arrow_length, ay=0, **arrow_style)
    # Rostral arrow: a positive pixel ay places the tail below the head -> points up.
    fig.add_annotation(x=compass_x, y=rostral_head_y, ax=0, ay=arrow_length, **arrow_style)

    # Labels just beyond each arrowhead.
    fig.add_annotation(x=left_head_x - 0.04, y=compass_y, text="Left", **label_style)
    fig.add_annotation(x=compass_x, y=rostral_head_y + 0.02, text="Rostral", **label_style)

    fig.update_layout(
        margin=dict(l=50, r=50, t=50, b=50)  # keep margins from hiding the annotations
    )
    return fig

In [13]:
#| label: figRat1cell

def _muscle_button(index):
    """Selection button for muscle ``index``, labelled with its abbreviation (id ``btn-<index>``)."""
    return html.Button(
        emgs['emgsabr'][index],
        id=f'btn-{index}',
        n_clicks=0,
        style={'padding': '10px', 'fontSize': '16px', 'marginBottom': '5px'},
    )


# One button per muscle; the callbacks react to their click counts.
buttons = [_muscle_button(index) for index in range(n_muscles)]

app = Dash(__name__)

# Right-hand column: title plus the vertical stack of muscle buttons.
selection_panel = html.Div(
    [
        html.H4("Muscle selection", style={'textAlign': 'center', 'marginBottom': '10px'}),
        html.Div(buttons, style={'display': 'flex', 'flexDirection': 'column', 'marginLeft': '10px'}),
    ],
    style={'display': 'flex', 'flexDirection': 'column', 'marginLeft': '10px'},
)

# Section A: heatmap graph side by side with the muscle selection panel.
heatmap_row = html.Div(
    [dcc.Graph(id='heatmap', style={'width': '70%', 'height': '600px'}), selection_panel],
    style={'display': 'flex', 'alignItems': 'center'},
)

# Section B: detail graph for the clicked electrode.
detail_section = html.Div([
    dcc.Graph(id='detailed-figure', style={'width': '100%', 'height': '500px'}),
    html.Div(id='dummy-output', style={'display': 'none'}),
])

app.layout = html.Div([
    html.H2("A. Rat motor response heatmap"),
    heatmap_row,
    # NOTE(review): this heading repeats section A's title although the graph
    # below shows per-electrode EMG details; consider a more specific title.
    html.H2("B. Rat motor response heatmap"),
    detail_section,
])

@app.callback(
    Output('heatmap', 'figure'),
    [Input(f'btn-{i}', 'n_clicks') for i in range(n_muscles)],
    prevent_initial_call=True
)
def update_heatmap(*args):
    """Redraw the electrode heatmap for the muscle whose button was clicked.

    ``args`` are the buttons' ``n_clicks`` values; only the id of the
    triggering button (``btn-<index>``) is used.

    Returns ``no_update`` when the trigger is not a muscle button, otherwise a
    heatmap figure with a compass, grey markers on zero-valued cells (treated
    as ground electrodes) and one tick per grid row/column.
    """
    if not ctx.triggered:
        return no_update

    button_id = ctx.triggered_id
    if not button_id or not button_id.startswith('btn-'):
        return no_update

    index = int(button_id.split('-')[1])
    heatmap_data = generate_heatmap_data(index)
    n_rows, n_cols = heatmap_data.shape[0], heatmap_data.shape[1]

    fig = go.Figure()
    plot_boussole(fig)
    fig.add_trace(go.Heatmap(
        z=heatmap_data,
        colorscale='Blues',
        # colorbar.title.{text,side,font} replaces the deprecated
        # titleside/titlefont attributes rejected by newer plotly releases.
        colorbar=dict(len=0.75,
                      title=dict(text="Normalized MEP (mV)", side="right", font=dict(size=14))),
        zmin=np.min(heatmap_data),
        zmax=np.max(heatmap_data),
        hoverinfo='x+y+z'
    ))

    # Grey dots on cells whose value is exactly 0 (shown as ground electrodes).
    zero_coords = np.argwhere(heatmap_data == 0)
    if zero_coords.size > 0:
        fig.add_trace(go.Scatter(
            x=zero_coords[:, 1],
            y=zero_coords[:, 0],
            mode='markers',
            marker=dict(color='grey', size=8, symbol='circle'),
            name="Ground electrode",
            showlegend=True
        ))

    fig.update_layout(
        clickmode='event+select',  # needed so clicks reach the detail callback
        autosize=False,
        width=600, height=600,
        # x runs over columns (axis 1) and y over rows (axis 0) of heatmap_data.
        xaxis=dict(
            title="X coordinates",
            scaleanchor="y",
            showgrid=False,
            tickvals=np.arange(n_cols),
            ticktext=[str(i) for i in range(n_cols)]
        ),
        yaxis=dict(
            title="Y coordinates",
            showgrid=False,
            autorange="reversed",
            range=[0, n_rows],
            tickvals=np.arange(n_rows),
            ticktext=[str(i) for i in range(n_rows)]
        ),
        margin=dict(l=50, r=50, t=50, b=50),
        plot_bgcolor="white",
        paper_bgcolor="white"
    )

    return fig

@app.callback(
    Output('detailed-figure', 'figure'),  # Update the detailed figure graph
    [Input('heatmap', 'clickData')] + [Input(f'btn-{i}', 'n_clicks') for i in range(n_muscles)],
    # Click timestamps identify the most recently selected muscle; as State
    # they do not trigger the callback on their own.
    [State(f'btn-{i}', 'n_clicks_timestamp') for i in range(n_muscles)],
    prevent_initial_call=True
)
def handle_heatmap_click(clickData, *btn_values):
    """Show the detail figure for the clicked electrode and the selected muscle.

    Parameters
    ----------
    clickData : dict or None
        Plotly click payload from the ``heatmap`` graph.
    *btn_values
        The ``n_muscles`` buttons' ``n_clicks`` values, followed by their
        ``n_clicks_timestamp`` values.

    Returns a placeholder figure when a muscle button (not the heatmap)
    triggered the callback, when no muscle has been selected yet, when the
    click does not map to a stimulation channel, or when building the detail
    figure fails.
    """
    default_fig = go.Figure()
    default_fig.update_layout(title="Click on the heatmap to see details")

    if ctx.triggered_id != 'heatmap' or not clickData:
        return default_fig

    # Most recently clicked muscle = largest click timestamp
    # (never-clicked buttons report -1 or None).
    stamps = [-1 if ts is None else ts for ts in btn_values[n_muscles:]]
    if not stamps or max(stamps) <= 0:
        return default_fig
    muscle_index = int(np.argmax(stamps))

    try:
        point = clickData['points'][0]
        clicked_x, clicked_y = point['x'], point['y']
    except (KeyError, IndexError, TypeError) as e:
        print(f"Error processing click: {e}")
        return default_fig

    print(f"Heatmap clicked at X: {clicked_x}, Y: {clicked_y}")

    # ch2xy[:, 0] is the grid row (heatmap y) and ch2xy[:, 1] the column (heatmap x).
    matches = np.flatnonzero((ch2xy[:, 1] == clicked_x) & (ch2xy[:, 0] == clicked_y))
    if matches.size == 0:
        print(f"No stimulation channel at X: {clicked_x}, Y: {clicked_y}")
        return default_fig
    iArray = matches[0]

    print(f"Updating detailed figure with iArray={iArray}, muscle={muscle_index}")
    try:
        return update_detailed_figure_plotly(iArray, muscle_index, d1)
    except Exception as e:
        # Best-effort UI: log the failure and keep the app responsive.
        print(f"Error processing click: {e}")
        return default_fig

if __name__ == '__main__':
    # Launch the Dash server on port 8052 with debug=True.
    app.run(port=8052, debug=True)