# In [None]:  (Jupyter cell marker left over from notebook export)
import dash
from dash import html, dcc
from dash.dependencies import Input, Output, State
import plotly.graph_objects as go
from plotly.colors import sample_colorscale
from plotly.subplots import make_subplots
import scipy.io
import pandas as pd
import numpy as np
from PIL import Image
import base64, io  # for uploads

# Stimulus onset / offset, in seconds on the trace time axis
stimStart, stimEnd = 1.0, 3.0

# Directory holding the sample recording loaded at start-up
generalpath = '/Volumes/gonzo/Victoria/data/layer4/m118/2025-04-30_rf/'

# --- Initial sample data load ---
# MATLAB struct 'data' (unwrapped from its 1x1 cell) and the normalized traces
data_mat = scipy.io.loadmat(generalpath + 'data.mat')['data'][0, 0]
normalized = scipy.io.loadmat(generalpath + 'm118_normalizedTraces.mat')['normalizedTraces']
frame_rate = float(data_mat['frame_rate'][0, 0])

# Header-less CSV with one row per ROI: label, x, y
rois_df = pd.read_csv(generalpath + 'roi_coords.csv', header=None, names=['roi', 'x', 'y'])
n_rois = len(rois_df)

# Background image shown behind the ROI markers, plus its pixel array
img = Image.open(generalpath + 'STD_MED_moco.png')
img_array = np.array(img)

# Precompute tracesAll + time_vector
def compute_traces():
    """Rebuild the module-level ``tracesAll`` and ``time_vector`` arrays.

    Reads the module globals ``normalized``, ``n_rois`` and ``frame_rate``.
    ``normalized`` is averaged over axis 0; axis 1 is taken as time and the
    remaining stimulus axis is folded into an ``n_grid_y`` x ``n_grid_x``
    grid, giving ``tracesAll`` the shape
    ``(n_grid_y, n_grid_x, n_time, n_rois)``. ``time_vector`` holds the
    sample times, i.e. sample index divided by ``frame_rate``.

    Also (re)sets the globals ``n_grid_x`` and ``n_grid_y`` to 8 and 5.
    """
    global tracesAll, time_vector, n_grid_x, n_grid_y
    n_grid_x, n_grid_y = 8, 5
    n_time = normalized.shape[1]

    # Average over axis 0 -> (time, stimulus, roi)
    trial_avg = normalized.mean(axis=0)

    grid_shape = (n_grid_y, n_grid_x, n_time)
    tracesAll = np.zeros(grid_shape + (n_rois,))
    for roi_idx in range(n_rois):
        # (time, stimulus) -> (stimulus, time), then unfold the stimulus
        # axis row-major into grid rows and columns.
        per_stimulus = trial_avg[:, :, roi_idx].T
        tracesAll[..., roi_idx] = per_stimulus.reshape(grid_shape)

    time_vector = np.arange(n_time) / frame_rate

compute_traces()

# --- Build base figure template once ---
# One subplot per stimulus grid position; the row/column counts have to agree
# with the (n_grid_y, n_grid_x) reshape done in compute_traces().
n_grid_x, n_grid_y = 8, 5

base_fig = make_subplots(
    rows=n_grid_y,
    cols=n_grid_x,
    horizontal_spacing=0.005,
    vertical_spacing=0.005,
)

# Each cell starts with an empty line trace plus a gray band over the
# stimulus window. Cells are visited row by row, so trace k belongs to
# grid cell divmod(k, n_grid_x). Subplot rows/columns are 1-based.
for r in range(1, n_grid_y + 1):
    for c in range(1, n_grid_x + 1):
        base_fig.add_trace(
            go.Scatter(x=[], y=[], mode='lines', line={'width': 1}),
            row=r, col=c,
        )
        base_fig.add_vrect(
            x0=stimStart, x1=stimEnd,
            fillcolor="LightGray", opacity=0.5,
            line_width=0, layer="below",
            row=r, col=c,
        )

# Strip every axis down to just a thin black zero line ...
_bare_axis_style = {
    'showticklabels': False,
    'showgrid': False,
    'showline': False,
    'zeroline': True,
    'zerolinewidth': 1,
    'zerolinecolor': 'black',
}
base_fig.update_xaxes(**_bare_axis_style)
base_fig.update_yaxes(**_bare_axis_style)

# ... then switch tick labels back on for the bottom row (x axes)
for c in range(1, n_grid_x + 1):
    base_fig.update_xaxes(showticklabels=True, ticks="outside", row=n_grid_y, col=c)
# ... and for the left-most column (y axes)
for r in range(1, n_grid_y + 1):
    base_fig.update_yaxes(showticklabels=True, ticks="outside", row=r, col=1)

# Transparent backgrounds, no legend, slim margins
base_fig.update_layout(
    showlegend=False,
    margin={'t': 30, 'b': 30, 'l': 30, 'r': 30},
    plot_bgcolor='rgba(0,0,0,0)',
    paper_bgcolor='rgba(0,0,0,0)',
)

# --- RF window ---
# Sample indices bounding the stimulus period (seconds * frames per second)
idx0 = int(stimStart * frame_rate)
idx1 = int(stimEnd * frame_rate)

# --- Precompute summary means and build a reusable summary template ---
# Row i is ROI i's timecourse averaged over every grid position;
# grand_mean further averages those rows over ROIs.
summary_means = np.stack([
    tracesAll[..., roi_idx].mean(axis=(0, 1)) for roi_idx in range(n_rois)
])
grand_mean = summary_means.mean(axis=0)


def _summary_line(values, color, width):
    """Return a hover-less Scattergl line of *values* against time_vector."""
    return go.Scattergl(
        x=time_vector,
        y=values,
        mode='lines',
        line={'color': color, 'width': width},
        hoverinfo='none',
    )


summary_template = go.Figure()
# Trace order matters: the n_rois gray per-ROI lines come first, so the red
# highlight trace sits at index n_rois (update_traces() relies on that),
# followed by the thick black grand mean.
for roi_mean in summary_means:
    summary_template.add_trace(_summary_line(roi_mean, 'gray', 1))
summary_template.add_trace(_summary_line(summary_means[0], 'red', 2))   # placeholder highlight
summary_template.add_trace(_summary_line(grand_mean, 'black', 4))

# Gray band marking the stimulus window
summary_template.add_vrect(
    x0=stimStart, x1=stimEnd,
    fillcolor="LightGray", opacity=0.5,
    line_width=0, layer="below",
)
summary_template.update_layout(
    title="Mean Across All Grids",
    title_x=0.5,
    xaxis_title="Time (s)",
    yaxis_title="dF/F",
    showlegend=False,
    margin={'t': 40, 'b': 40, 'l': 40, 'r': 20},
    plot_bgcolor='rgba(0,0,0,0)',
    paper_bgcolor='rgba(0,0,0,0)',
)

# --- Precompute RF peak and center-of-mass (COM) locations for each ROI ---
# All coordinates are in grid-cell units on the vertically flipped response
# map, i.e. the orientation in which the heatmap view draws it.
def _weighted_center(weights, x_idx, y_idx):
    """Return the (x, y) center of mass of a 2-D weight map.

    ``x_idx`` / ``y_idx`` give the column / row coordinate of every cell.
    Returns ``(nan, nan)`` when the total weight is not positive, so ROIs
    without any usable response are left undefined.
    """
    total = weights.sum()
    if total > 0:
        return (weights * x_idx).sum() / total, (weights * y_idx).sum() / total
    return np.nan, np.nan


# Row (Y) and column (X) index of every grid cell
Y_idx, X_idx = np.indices((n_grid_y, n_grid_x))

peak_x    = np.zeros(n_rois)
peak_y    = np.zeros(n_rois)
com_raw_x = np.zeros(n_rois)
com_raw_y = np.zeros(n_rois)
com_bs_x  = np.zeros(n_rois)
com_bs_y  = np.zeros(n_rois)

for i in range(n_rois):
    # Mean response over the stimulus window for every grid cell
    raw_z = tracesAll[:, :, idx0:idx1, i].mean(axis=2)

    # Baseline-subtracted (minimum shifted to zero) map, flipped vertically
    zflip = np.flipud(raw_z - raw_z.min())
    # Negative responses clipped to zero, flipped the same way
    zclipf = np.flipud(np.clip(raw_z, 0, None))

    # 1) Peak location on the flipped map (row -> y, column -> x)
    peak_y[i], peak_x[i] = np.unravel_index(np.argmax(zflip), zflip.shape)

    # 2) COM of the clipped-to-zero map
    com_raw_x[i], com_raw_y[i] = _weighted_center(zclipf, X_idx, Y_idx)

    # 3) COM of the baseline-subtracted map
    com_bs_x[i], com_bs_y[i] = _weighted_center(zflip, X_idx, Y_idx)

# The COM used by the summary scatter is the clipped-to-zero variant. It used
# to be computed in a separate, identical loop; copies keep the arrays
# independent of com_raw_x / com_raw_y.
com_x = com_raw_x.copy()
com_y = com_raw_y.copy()


# --- Build image figure ---
# Image size in pixels; also the extent of the plot's data coordinates
_img_height, _img_width = img_array.shape[0], img_array.shape[1]

fig_img = go.Figure()

# Stretch the background image over [0, width] x [0, height]; a layout
# image is anchored at its top-left corner, hence y = height.
fig_img.add_layout_image({
    'source': img,
    'xref': "x",
    'yref': "y",
    'x': 0,
    'y': _img_height,
    'sizex': _img_width,
    'sizey': _img_height,
    'sizing': "stretch",
    'layer': "below",
})

# Hidden axes, with y locked to x at a 1:1 ratio so pixels stay square
fig_img.update_xaxes(visible=False, autorange=True)
fig_img.update_yaxes(visible=False, autorange=True, scaleanchor="x", scaleratio=1)
fig_img.update_layout(
    clickmode='event+select',
    autosize=True,
    margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
    plot_bgcolor="white",
    paper_bgcolor="white",
)

# ROI markers. Image rows count downward while the plot's y axis counts
# upward, so the y coordinate is mirrored about the image height.
fig_img.add_trace(go.Scatter(
    x=rois_df['x'],
    y=_img_height - rois_df['y'],
    mode='markers+text',
    text=rois_df['roi'],
    textposition='top center',
    marker={'size': 6, 'color': 'red', 'line': {'width': 1}},
    hovertemplate="ROI %{text}<br>x: %{x}<br>y: %{y}<extra></extra>",
))

# --- Dash App & Layout ---
app = dash.Dash(__name__)

# --- Title ---
_title = html.H1("Victoria's RF Explorer", style={'textAlign': 'center', 'margin': '10px 0'})

# --- Navigation buttons ---
_nav_buttons = html.Div(
    [
        html.Button('← Previous ROI', id='prev-roi', n_clicks=0),
        html.Button('Next ROI →', id='next-roi', n_clicks=0),
    ],
    style={'display': 'flex', 'justifyContent': 'center', 'gap': '20px', 'padding': '10px'},
)

# --- Controls row: data loading on the left, view toggles on the right ---
_data_controls = html.Div(
    [
        dcc.Upload(id='upload-data', children=html.Button('📂 Load Your Files'), multiple=True),
        html.Span(id='upload-status'),
        html.Button('🔄 Load Sample Data', id='load-sample-btn', n_clicks=0),
        html.Span(id='sample-status'),
    ],
    style={'display': 'flex', 'gap': '10px', 'alignItems': 'center'},
)
_view_options = dcc.Checklist(
    id='view-options',
    options=[
        {'label': 'Show Individual Trials', 'value': 'show_trials'},
        {'label': 'Show Heatmap', 'value': 'show_heatmap'},
    ],
    value=[],
    inline=True,
    inputStyle={"margin": "0 10px"},
)
_controls_row = html.Div(
    [_data_controls, _view_options],
    style={
        'display': 'flex',
        'justifyContent': 'space-between',
        'alignItems': 'center',
        'padding': '0 20px',
    },
)

# Left: 2p image with clickable ROI markers
_image_panel = html.Div(
    [
        html.H2("Select ROI", style={'textAlign': 'center', 'margin': '5px 0'}),
        dcc.Graph(
            id='rf-image',
            figure=fig_img,
            style={'flex': '1', 'height': '100%', 'width': '100%'},
            config={'responsive': True},
        ),
    ],
    style={'flex': '1', 'display': 'flex', 'flexDirection': 'column', 'minWidth': 0},
)

# Right, top: RF traces panel (spinner appears only after a 750 ms delay)
_traces_panel = html.Div(
    [
        html.H2("ROI 1 RF Traces", id='trace-title', style={'textAlign': 'center', 'margin': '2px 0'}),
        dcc.Loading(
            id='loading-traces',
            type='circle',
            delay_show=750,
            children=[
                dcc.Graph(
                    id='rf-traces',
                    config={'responsive': True},
                    style={'flex': '1 1 auto', 'width': '100%'},
                ),
            ],
        ),
    ],
    style={'display': 'flex', 'flexDirection': 'column', 'height': '100%'},
)

# Right, bottom: summary panel
_summary_graphs = html.Div(
    [
        # Left: mean timecourse
        dcc.Graph(
            id='rf-summary',
            figure=summary_template,
            config={'responsive': True},
            style={'flex': '1', 'height': '100%'},
        ),
        # Right: COM scatter
        dcc.Graph(
            id='rf-com-summary',
            config={'responsive': True},
            style={'flex': '1', 'height': '100%'},
        ),
    ],
    style={'display': 'flex', 'flex': '1', 'gap': '10px'},
)
_summary_panel = html.Div(
    [
        # No top margin, so the heading hugs the panel above
        html.H2("Summary Metrics", style={'textAlign': 'center', 'margin': '0 0 4px 0'}),
        _summary_graphs,
    ],
    style={'display': 'flex', 'flexDirection': 'column', 'height': '100%'},
)

# Right: traces above summary, in a 2:1 grid
_right_column = html.Div(
    [_traces_panel, _summary_panel],
    style={
        'flex': '1',
        'display': 'grid',
        'height': '100%',
        'gridTemplateRows': '2fr 1fr',
        'rowGap': '0px',
        'minWidth': 0,
    },
)

# --- Main content ---
_main_content = html.Div(
    [_image_panel, _right_column],
    style={'display': 'flex', 'flex': '1', 'overflow': 'hidden', 'gap': '10px'},
)

app.layout = html.Div(
    [
        _title,
        _nav_buttons,
        _controls_row,
        # Store current ROI index (zero-based)
        dcc.Store(id='current-roi', data=0),
        _main_content,
    ],
    style={'display': 'flex', 'flexDirection': 'column', 'height': '100vh', 'width': '100vw'},
)

# --- Callbacks ---
# File extensions accepted for the uploaded background image
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _upload_role(name):
    """Classify a lower-cased upload filename.

    Returns 'data', 'normalized', 'rois' or 'image', or None for a file
    that is none of the four expected inputs.
    """
    if name == 'data.mat':
        return 'data'
    if 'normalized' in name:
        return 'normalized'
    if name.endswith('roi_coords.csv'):
        return 'rois'
    if name.endswith(_IMAGE_EXTENSIONS):
        return 'image'
    return None


# --- Callbacks ---
@app.callback(
    Output('upload-status','children'),
    Input('upload-data','contents'),
    State('upload-data','filename')
)
def load_upload(contents_list, filenames):
    """Replace the module-level dataset with a set of user-uploaded files.

    Args:
        contents_list: Upload contents, one string per file, each of the
            form "<header>,<base64 payload>".
        filenames: Names of the uploaded files, parallel to contents_list.

    Returns:
        An ``html.Span`` status message: red, listing what is missing, when
        the upload is incomplete; green on success.

    Raises:
        dash.exceptions.PreventUpdate: when nothing was uploaded.

    All four files are parsed into locals first and the globals are only
    replaced once every one of them loaded, so a file that fails to parse
    cannot leave a mix of old and new data behind. Files that match none of
    the expected names are ignored (previously any such file was opened as
    the background image).

    NOTE(review): only tracesAll / time_vector are rebuilt here (via
    compute_traces()). idx0/idx1, summary_means, summary_template and the
    peak/COM arrays are computed once at import time and still describe the
    data that was loaded at start-up.
    """
    global data_mat, normalized, frame_rate, rois_df, n_rois, img, img_array

    if not contents_list or not filenames:
        raise dash.exceptions.PreventUpdate

    # Role -> contents of the last uploaded file that fills that role
    by_role = {}
    for name, content in zip(filenames, contents_list):
        role = _upload_role(name.lower())
        if role is not None:
            by_role[role] = content

    required = (
        ('data', 'data.mat'),
        ('normalized', 'normalizedTraces.mat'),
        ('rois', 'roi_coords.csv'),
        ('image', 'background image (.png)'),
    )
    missing = [label for role, label in required if role not in by_role]
    if missing:
        return html.Span(f"❌ Missing: {', '.join(missing)}", style={'color':'red'})

    def payload(role):
        # Drop the header in front of the first comma, decode the rest
        return base64.b64decode(by_role[role].split(',', 1)[1])

    # Parse everything before touching any global
    new_data_mat = scipy.io.loadmat(io.BytesIO(payload('data')))['data'][0, 0]
    new_frame_rate = float(new_data_mat['frame_rate'][0, 0])
    new_normalized = scipy.io.loadmat(io.BytesIO(payload('normalized')))['normalizedTraces']
    new_rois_df = pd.read_csv(
        io.StringIO(payload('rois').decode()),
        header=None, names=['roi','x','y'],
    )
    new_img = Image.open(io.BytesIO(payload('image')))
    new_img_array = np.array(new_img)

    # Commit the new dataset in one go
    data_mat, normalized, frame_rate = new_data_mat, new_normalized, new_frame_rate
    rois_df, n_rois = new_rois_df, len(new_rois_df)
    img, img_array = new_img, new_img_array

    compute_traces()
    return html.Span("✅ Files loaded.", style={'color':'green'})

@app.callback(
    Output('sample-status','children'),
    Input('load-sample-btn','n_clicks')
)
def load_sample(n_clicks):
    """Reload the bundled sample dataset into the module-level globals.

    Args:
        n_clicks: Click count of the "Load Sample Data" button.

    Returns:
        The status string shown next to the button.

    Raises:
        dash.exceptions.PreventUpdate: while the button has not been clicked.

    The files are read from ``generalpath``, the same directory the
    start-up load uses; this callback previously pointed at a separate
    hard-coded ``Y:\\`` drive path, so it failed wherever that mapping did
    not exist. Everything is read into locals first and the globals are
    replaced only after all four files loaded.
    """
    global data_mat, normalized, frame_rate, rois_df, n_rois, img, img_array

    if not n_clicks:
        raise dash.exceptions.PreventUpdate

    new_data_mat = scipy.io.loadmat(generalpath + 'data.mat')['data'][0, 0]
    new_normalized = scipy.io.loadmat(
        generalpath + 'm118_normalizedTraces.mat'
    )['normalizedTraces']
    new_frame_rate = float(new_data_mat['frame_rate'][0, 0])
    new_rois_df = pd.read_csv(
        generalpath + 'roi_coords.csv',
        header=None, names=['roi','x','y'],
    )
    new_img = Image.open(generalpath + 'STD_MED_moco.png')
    new_img_array = np.array(new_img)

    # Commit the new dataset in one go, then rebuild the derived traces
    data_mat, normalized, frame_rate = new_data_mat, new_normalized, new_frame_rate
    rois_df, n_rois = new_rois_df, len(new_rois_df)
    img, img_array = new_img, new_img_array

    compute_traces()
    return "✅ Sample data loaded."

@app.callback(
    Output('rf-image','figure'),
    Input('upload-data','contents'),
    Input('load-sample-btn','n_clicks'),
    # The two status spans are written by the callbacks that replace the
    # image / ROI globals. Listing them as inputs makes Dash run this
    # callback after those loaders instead of alongside them, so the figure
    # is never built from data that is about to be swapped out.
    Input('upload-status','children'),
    Input('sample-status','children'),
)
def update_rf_image(_, __, ___=None, ____=None):
    """Rebuild the ROI-picker figure from the current module-level data.

    The arguments only act as triggers; their values are ignored. The
    figure is built from the globals ``img``, ``img_array`` and ``rois_df``.

    Returns:
        A ``go.Figure`` with the background image and one labelled red
        marker per ROI.
    """
    height, width = img_array.shape[0], img_array.shape[1]

    fig = go.Figure()
    # A layout image is anchored at its top-left corner, hence y = height
    fig.add_layout_image(dict(source=img, xref="x", yref="y",
                              x=0, y=height,
                              sizex=width, sizey=height,
                              sizing="stretch", layer="below"))
    fig.update_xaxes(visible=False, autorange=True)
    fig.update_yaxes(visible=False, autorange=True, scaleanchor="x", scaleratio=1)
    fig.update_layout(clickmode='event+select', margin=dict(l=0,r=0,t=0,b=0),
                      autosize=True, plot_bgcolor="white", paper_bgcolor="white")
    # Image rows count downward, the plot's y axis upward: mirror y
    fig.add_trace(go.Scatter(
        x=rois_df['x'], y=height-rois_df['y'],
        mode='markers+text',
        marker=dict(size=6, color='red', line=dict(width=1)),
        text=rois_df['roi'], textposition='top center',
        hovertemplate="ROI %{text}<br>x: %{x}<br>y: %{y}<extra></extra>"
    ))
    return fig

@app.callback(
    Output('current-roi','data'),
    [
        Input('prev-roi','n_clicks'),
        Input('next-roi','n_clicks'),
        Input('rf-image','clickData'),
    ],
    State('current-roi','data')
)
def navigate(prev_clicks, next_clicks, clickData, current):
    """Return the new zero-based ROI index after a navigation event.

    Args:
        prev_clicks: Click count of the "previous" button (trigger only).
        next_clicks: Click count of the "next" button (trigger only).
        clickData: Plotly click payload from the ROI image, or None.
        current: ROI index currently held in the store.

    Returns:
        ``current`` stepped down (not below 0) or up (not above
        ``n_rois - 1``) for a button press; the index of the ROI closest to
        the clicked point for an image click; ``current`` unchanged
        otherwise.
    """
    triggered = dash.callback_context.triggered
    if not triggered:
        return current

    # prop_id looks like "<component id>.<property>"; keep the component id
    source_id = triggered[0]['prop_id'].split('.')[0]

    if source_id == 'prev-roi':
        return max(0, current - 1)
    if source_id == 'next-roi':
        return min(n_rois - 1, current + 1)
    if source_id != 'rf-image' or not clickData:
        return current

    clicked = clickData['points'][0]
    click_x = clicked['x']
    # Markers are drawn at (x, image height - y); undo the mirroring to get
    # back to the coordinates stored in rois_df.
    click_y = img_array.shape[0] - clicked['y']

    roi_xy = rois_df[['x', 'y']].to_numpy()
    distances = np.hypot(roi_xy[:, 0] - click_x, roi_xy[:, 1] - click_y)
    return int(np.argmin(distances))

def _build_com_figure(current_roi):
    """Scatter of every ROI's RF center of mass, selected ROI drawn in red."""
    colors = ['gray'] * n_rois
    colors[current_roi] = 'red'

    com_fig = go.Figure()
    com_fig.add_trace(go.Scatter(
        x=com_x, y=com_y,
        mode='markers+text',
        marker=dict(
            size=8,
            color=colors,               # per-point color list
            line=dict(width=1, color=colors)
        ),
        text=[str(i+1) for i in range(n_rois)],
        textposition='top center',
        hovertemplate="ROI %{text}<br>x: %{x:.2f}<br>y: %{y:.2f}<extra></extra>"
    ))
    com_fig.update_layout(
        title="RF COM For All ROIs",
        title_x=0.5,
        xaxis=dict(
            title="Grid X",
            range=[-0.5, n_grid_x-0.5],   # half a cell beyond the first and last column
            tick0=0,
            dtick=1,                      # one tick / grid line per grid cell
            showgrid=True,
            gridcolor='lightgrey',
            gridwidth=1,
            zeroline=False
        ),
        yaxis=dict(
            title="Grid Y",
            range=[-0.5, n_grid_y-0.5],
            tick0=0,
            dtick=1,
            showgrid=True,
            gridcolor='lightgrey',
            gridwidth=1,
            zeroline=False
        ),
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
        margin=dict(t=40, b=20, l=40, r=20)
    )
    return com_fig


def _build_heatmap_figure(current_roi):
    """Heatmap of the ROI's mean stimulus-window response with peak/COM markers."""
    # Mean over the stimulus window is already (n_grid_y, n_grid_x); flip it
    # vertically, which is the orientation the precomputed peak/COM use.
    z = np.flipud(tracesAll[:, :, idx0:idx1, current_roi].mean(axis=2))
    fig = go.Figure(go.Heatmap(z=z, colorscale='Viridis', colorbar=dict(title='Mean dF/F')))

    # (legend name, x array, y array, marker symbol, outline color)
    overlays = (
        ('peak',    peak_x,    peak_y,    'circle-open', 'yellow'),
        ('COM raw', com_raw_x, com_raw_y, 'x',           'white'),
        ('COM BS',  com_bs_x,  com_bs_y,  'x',           'cyan'),
    )
    for name, xs, ys, symbol, color in overlays:
        fig.add_trace(go.Scatter(
            x=[xs[current_roi]], y=[ys[current_roi]],
            mode='markers',
            marker=dict(symbol=symbol, size=14, line=dict(width=2, color=color)),
            name=name
        ))

    fig.update_layout(xaxis=dict(showticklabels=False), yaxis=dict(showticklabels=False),
                      plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)',
                      margin=dict(t=40,b=40,l=40,r=40))
    return fig


def _build_trace_grid_figure(current_roi, show_trials):
    """Grid of mean traces (one subplot per stimulus position) for one ROI."""
    fig = go.Figure(base_fig)            # work on a copy of the template
    n_cells = n_grid_y * n_grid_x
    grid = tracesAll[:, :, :, current_roi]

    # Color each cell's line by its mean stimulus-window response,
    # rescaled to [0, 1] (the 1e-6 floor avoids dividing by zero).
    amp = grid[:, :, idx0:idx1].mean(axis=2).flatten()
    norm = (amp - amp.min())/max(amp.max()-amp.min(), 1e-6)
    cols = sample_colorscale('Viridis', norm.tolist())

    if show_trials:
        trials = normalized[:, :, :, current_roi]
        trial_traces, rows, columns = [], [], []
        for cell in range(n_cells):
            r, c = divmod(cell, n_grid_x)
            for t in trials[:, :, cell]:
                trial_traces.append(go.Scattergl(
                    x=time_vector, y=t,
                    mode='lines', line=dict(color='gray', width=0.5),
                    opacity=0.3))
                rows.append(r+1)
                columns.append(c+1)
        # One batched call instead of one add_trace() per trial
        if trial_traces:
            fig.add_traces(trial_traces, rows=rows, cols=columns)

    # The template's own traces come first, one per cell in row-major order
    for cell in range(n_cells):
        r, c = divmod(cell, n_grid_x)
        tr = fig.data[cell]
        tr.x = time_vector
        tr.y = grid[r, c, :]
        tr.line.color = cols[cell]

    fig.update_layout(plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)')
    fig.update_yaxes(range=[grid.min(), grid.max()])
    return fig


@app.callback(
    Output('rf-traces','figure'),
    Output('rf-summary','figure'),
    Output('rf-com-summary','figure'),
    Output('trace-title','children'),
    Input('view-options','value'),
    Input('current-roi','data'),
)
def update_traces(view_opts, current_roi):
    """Redraw the trace, summary and COM panels for the selected ROI.

    Args:
        view_opts: Checked values of the view checklist ('show_trials',
            'show_heatmap'), or None.
        current_roi: Zero-based index of the selected ROI.

    Returns:
        ``(main figure, summary figure, COM figure, panel title)``. The main
        figure is the activation heatmap when 'show_heatmap' is checked and
        the per-position trace grid otherwise.

    Raises:
        dash.exceptions.PreventUpdate: when no ROI index is available.
    """
    if current_roi is None:
        raise dash.exceptions.PreventUpdate

    selected = view_opts or []
    show_trials  = 'show_trials'  in selected
    show_heatmap = 'show_heatmap' in selected

    # 1) Summary timecourse: highlight the selected ROI on a copy, so the
    #    shared module-level template is never modified by a request. The
    #    highlight trace follows the per-ROI traces in the template.
    summary_fig = go.Figure(summary_template)
    summary_fig.data[len(summary_means)].y = summary_means[current_roi]

    # 2) COM scatter, shown in both views
    com_fig = _build_com_figure(current_roi)

    # 3) Main panel and its title
    base_title = f"ROI {current_roi+1}"
    if show_heatmap:
        fig = _build_heatmap_figure(current_roi)
        return fig, summary_fig, com_fig, base_title + " Activation Heatmap"

    fig = _build_trace_grid_figure(current_roi, show_trials)
    return fig, summary_fig, com_fig, base_title + " RF Traces"

# Start the Dash server, with debug mode on, only when this file is run
# directly (not when it is imported).
if __name__ == '__main__':
    app.run(debug=True)
    # Address noted by the original author for opening the app in a browser:
    #127.0.0.1:8050
    #http://127.0.0.1:8050/
