In [153]:
#IMPORTATIONS#

# core importations
import caveclient
from caveclient import CAVEclient
import cloudvolume
from cloudvolume import CloudVolume
import pandas as pd
import numpy as np

# dash importations
import dash
from dash import Dash, dcc, html, Input, Output, State, dash_table

In [154]:
#FUNCTIONS#

# defines function to convert nm coordinates into FlyWire-usable 4,4,40 voxel coordinates #
def coordConvert(coords, resolution=(4, 4, 40)):
    """Convert an x,y,z position in nm into voxel coordinates.

    Args:
        coords: mutable x,y,z sequence (e.g. a list or numpy array) in nm.
        resolution: nm per voxel along x, y and z; defaults to FlyWire's
            4x4x40 nm voxel size.

    Returns:
        A copy of ``coords`` (same container type) with each axis divided by
        its resolution. The input itself is left unchanged. Note that an
        integer numpy array keeps its dtype, so its quotients are truncated.
    """
    import copy

    # copies first so the caller's data (e.g. a row of a query result) is not modified #
    x = copy.copy(coords)
    for axis, voxel_size in enumerate(resolution):
        x[axis] /= voxel_size
    return x

# defines function to convert list of x,y,z coordinates in [4,4,40] resolution to root id #
def coordsToRoot(coords, client=None, cv=None):
    """Look up the FlyWire root id of the segment at one x,y,z point.

    Args:
        coords: three x,y,z values in 4x4x40 voxel units. Numeric strings such
            as '12345' or '12345.0' are accepted.
        client: optional CAVEclient to reuse; a new
            "flywire_fafb_production" client is created when omitted.
        cv: optional CloudVolume to reuse; the fly_v31 graphene segmentation
            is opened when omitted.

    Returns:
        The root id as a string.
    """
    # converts coordinates to ints (through float so '123.0' is accepted) #
    coords = [int(float(c)) for c in coords]

    # sets client and cloud volume unless the caller supplied reusable ones #
    if client is None:
        client = CAVEclient("flywire_fafb_production")
    if cv is None:
        cv = cloudvolume.CloudVolume("graphene://https://prod.flywire-daf.com/segmentation/1.0/fly_v31", use_https=True)

    # rescales the 4x4x40 coordinates to the volume's own resolution #
    res = cv.resolution
    cv_xyz = [
        int(coords[0] / (res[0] / 4)),
        int(coords[1] / (res[1] / 4)),
        int(coords[2] / (res[2] / 40)),
    ]

    # downloads a single voxel; the cutout is a multi-dimensional array, so take its
    # only element explicitly (int() on an ndim>0 array is deprecated in NumPy) #
    cutout = cv.download_point(cv_xyz, size=1)
    point = int(np.asarray(cutout).flat[0])

    # looks up root id for that supervoxel using chunkedgraph, converts to string #
    return str(client.chunkedgraph.get_root_id(supervoxel_id=point))

# defines function for querying nucleus table using list of root or nuc ids #
# query type ('nuc' or 'root') is inferred from id length unless given explicitly #
def getNuc(id_list, query_type=None):
    """Query the FlyWire ``nuclei_v1`` table for a list of root or nucleus IDs.

    Args:
        id_list: non-empty list of IDs (strings or ints). All entries are
            treated as the same kind of ID as the first one.
        query_type: 'nuc' to filter on the nucleus ``id`` column or 'root' to
            filter on ``pt_root_id``. When None it is inferred from the first
            ID: 7 characters -> 'nuc', 18 characters -> 'root'.

    Returns:
        DataFrame of strings with 'Root ID', 'Nucleus ID' and
        'Nucleus Coordinates' columns (coordinates passed through coordConvert).

    Raises:
        ValueError: if id_list is empty, the first ID's length is not
            recognised, or query_type is not 'nuc'/'root'.
    """
    # nucleus-table column that each query type filters on #
    filter_cols = {'nuc': 'id', 'root': 'pt_root_id'}

    # validates input before opening a client so bad entries fail fast with a clear message #
    if len(id_list) == 0:
        raise ValueError('No root or nucleus ID was given.')

    if query_type is None:
        first_id = str(id_list[0]).strip()
        if len(first_id) == 7:
            query_type = 'nuc'
        elif len(first_id) == 18:
            query_type = 'root'
        else:
            raise ValueError(
                'Could not tell whether ' + repr(first_id) + ' is a nucleus ID '
                '(7 digits) or a root ID (18 digits).'
            )

    if query_type not in filter_cols:
        raise ValueError("query_type must be 'nuc' or 'root', not " + repr(query_type))

    # sets client and gets current (latest) materialization version #
    client = CAVEclient("flywire_fafb_production")
    mat_vers = max(client.materialize.get_versions())

    # pulls nucleus table results based on query type #
    nuc_df = client.materialize.query_table(
        'nuclei_v1',
        filter_in_dict={filter_cols[query_type]: id_list},
        materialization_version=mat_vers,
    )

    # converts nucleus coordinates from nm to 4x4x40 resolution #
    nuc_df['pt_position'] = [coordConvert(i) for i in nuc_df['pt_position']]

    # creates output dataframe using root id, nuc id, and nuc coords from table to keep alignment #
    out_df = pd.DataFrame({
        'Root ID': list(nuc_df['pt_root_id']),
        'Nucleus ID': list(nuc_df['id']),
        'Nucleus Coordinates': list(nuc_df['pt_position']),
    })

    return out_df.astype(str)

# defines function to create df of nt averages by passing: #
# list of partner ids, pre- or post-synapse df, and matching column name #
def ntMeans(ids, df, col_name):
    """Average the neurotransmitter scores of each partner's synapses.

    Args:
        ids: partner root ids, in the order the output rows should follow.
        df: synapse table with 'gaba', 'ach', 'glut', 'oct', 'ser' and 'da'
            columns plus ``col_name``.
        col_name: column of ``df`` that holds the partner id.

    Returns:
        DataFrame with one row per id: 'Partner ID' followed by the six
        '<Nt> Avg' columns, each mean rounded to 3 decimals (NaN for an id
        with no rows). An empty ``ids`` yields an empty frame that still has
        these columns.
    """
    # output column -> synapse-table column it averages #
    nt_cols = {
        'Gaba Avg': 'gaba',
        'Ach Avg': 'ach',
        'Glut Avg': 'glut',
        'Oct Avg': 'oct',
        'Ser Avg': 'ser',
        'Da Avg': 'da',
    }

    # collects one record per partner and builds the frame once at the end
    # (DataFrame.append was removed in pandas 2.0 and was quadratic anyway) #
    records = []
    for partner_id in ids:
        partner_df = df.loc[df[col_name] == partner_id]
        row = {'Partner ID': partner_id}
        for out_col, nt_col in nt_cols.items():
            row[out_col] = round(partner_df[nt_col].mean(), 3)
        records.append(row)

    return pd.DataFrame(records, columns=['Partner ID', *nt_cols])

# defines function to validate constructed synapse data #
def checkValid(up_df, down_df, incoming_df, outgoing_df):
    """Cross-check built connection counts against the raw synapse tables.

    Args:
        up_df: built upstream table with string 'Upstream Partner ID' and
            'Connections' columns.
        down_df: built downstream table with string 'Downstream Partner ID'
            and 'Connections' columns.
        incoming_df: synapse rows onto the neuron ('pre_pt_root_id' column).
        outgoing_df: synapse rows from the neuron ('post_pt_root_id' column).

    Returns:
        A status string: either 'All N items have been validated.' or a
        message naming the first partner whose counts disagree.
    """
    from collections import Counter

    # counts queried synapses per partner once, as strings to match the built tables #
    incoming_counts = Counter(incoming_df['pre_pt_root_id'].astype(str))
    outgoing_counts = Counter(outgoing_df['post_pt_root_id'].astype(str))

    # (label, built table, id column, queried counts) for each direction, upstream first #
    checks = [
        ('Upstream', up_df, 'Upstream Partner ID', incoming_counts),
        ('Downstream', down_df, 'Downstream Partner ID', outgoing_counts),
    ]

    counter = 0
    for direction, built_df, id_col, queried in checks:
        for partner_id, built_con in zip(built_df[id_col], built_df['Connections']):
            quer_con = str(queried[partner_id])
            if built_con == quer_con:
                counter += 1
            else:
                # str() on counter: it is an int and cannot be concatenated directly #
                return (
                    str(counter) + ' items validated. ' + direction
                    + ' data false for partner ' + str(partner_id)
                    + '. Built count = ' + str(built_con)
                    + '. Query count = ' + quer_con
                )

    return 'All ' + str(counter) + ' items have been validated.'

# defines helper that removes synapses below the cleft threshold, autapses, and #
# synapses whose partner (the non-queried side) is the 0 root #
def _cleanSynapses(syn_df, cleft_thresh, partner_col):
    """Return the rows of syn_df worth counting, re-indexed from 0.

    Args:
        syn_df: synapse table with 'cleft_score', 'pre_pt_root_id' and
            'post_pt_root_id' columns.
        cleft_thresh: minimum cleft score to keep.
        partner_col: column holding the partner's root id
            ('post_pt_root_id' for outgoing, 'pre_pt_root_id' for incoming).
    """
    keep = (
        (syn_df['cleft_score'] >= cleft_thresh)
        & (syn_df['pre_pt_root_id'] != syn_df['post_pt_root_id'])
        & (syn_df[partner_col] != 0)
    )
    return syn_df[keep].reset_index(drop=True)


# defines function to get synapse info using root ID #
def getSyn(root_id, cleft_thresh=0, validate=False):
    """Build synapse summary and per-partner tables for one neuron.

    Args:
        root_id: one-element list holding the neuron's root id.
        cleft_thresh: minimum cleft score for a synapse to be counted.
        validate: when truthy, cross-check the built counts with checkValid.

    Returns:
        [summary_df, up_df, down_df, status] where the three frames are
        all-string and status is checkValid's message or 'Data not validated'.
    """
    # sets client and gets current (latest) materialization version #
    client = CAVEclient("flywire_fafb_production")
    mat_vers = max(client.materialize.get_versions())

    # makes dfs of pre- (outgoing) and post- (incoming) synapses #
    outgoing_syn_df = client.materialize.query_table(
        'synapses_nt_v1',
        filter_in_dict={"pre_pt_root_id": root_id},
        materialization_version=mat_vers,
    )
    incoming_syn_df = client.materialize.query_table(
        'synapses_nt_v1',
        filter_in_dict={"post_pt_root_id": root_id},
        materialization_version=mat_vers,
    )

    # the 0-root check must look at the partner side: for incoming synapses the
    # post side is always the queried neuron, so filter on pre_pt_root_id there #
    outgoing_syn_df = _cleanSynapses(outgoing_syn_df, cleft_thresh, 'post_pt_root_id')
    incoming_syn_df = _cleanSynapses(incoming_syn_df, cleft_thresh, 'pre_pt_root_id')

    # calculates total synapses #
    in_count = len(incoming_syn_df)
    out_count = len(outgoing_syn_df)

    # gets lists of partners in first-appearance order #
    downstream_partners = list(outgoing_syn_df['post_pt_root_id'].drop_duplicates())
    upstream_partners = list(incoming_syn_df['pre_pt_root_id'].drop_duplicates())

    # builds output dataframes #
    summary_df = pd.DataFrame({
        'Root ID': root_id,
        'Incoming': in_count,
        'Outgoing': out_count,
        'Upstream Partners': len(upstream_partners),
        'Downstream Partners': len(downstream_partners),
    })
    up_df = pd.DataFrame({'Partner ID': upstream_partners})
    down_df = pd.DataFrame({'Partner ID': downstream_partners})

    # adds number of connections between input neuron and partners (one count pass per table) #
    up_df['Connections'] = up_df['Partner ID'].map(incoming_syn_df['pre_pt_root_id'].value_counts())
    down_df['Connections'] = down_df['Partner ID'].map(outgoing_syn_df['post_pt_root_id'].value_counts())

    # adds neurotransmitter averages for each partner #
    up_df = up_df.join(
        ntMeans(upstream_partners, incoming_syn_df, 'pre_pt_root_id').set_index('Partner ID'),
        on='Partner ID',
    )
    down_df = down_df.join(
        ntMeans(downstream_partners, outgoing_syn_df, 'post_pt_root_id').set_index('Partner ID'),
        on='Partner ID',
    )

    # renames partner id columns to up/downstream #
    up_df = up_df.rename(columns={"Partner ID": "Upstream Partner ID"})
    down_df = down_df.rename(columns={"Partner ID": "Downstream Partner ID"})

    # converts all data to strings #
    summary_df = summary_df.astype(str)
    up_df = up_df.astype(str)
    down_df = down_df.astype(str)

    # runs data validation if requested #
    if validate:
        val_out = checkValid(up_df, down_df, incoming_syn_df, outgoing_syn_df)
        return [summary_df, up_df, down_df, val_out]
    return [summary_df, up_df, down_df, 'Data not validated']

# defines function to build the output tables from a list-formatted root/nuc id or x,y,z coords #
def dfBuilder(input_list, cleft_thresh, validate):
    """Build the summary, upstream and downstream tables for one neuron.

    Args:
        input_list: one root/nuc id as a single-item list, or three
            coordinate values (resolved to a root id first).
        cleft_thresh: minimum cleft score passed through to getSyn.
        validate: whether getSyn should cross-check its counts.

    Returns:
        [summary df (nucleus info joined with synapse summary),
         upstream df, downstream df, validation status]
    """
    # three values are treated as coordinates and swapped for the root id there #
    ids = [coordsToRoot(input_list)] if len(input_list) == 3 else input_list

    nuc_df = getNuc(ids)

    # synapse tables are keyed by the root id of the first nucleus row #
    first_root = str(nuc_df.loc[0, 'Root ID'])
    syn_sum_df, up_df, down_df, val_status = getSyn([first_root], cleft_thresh, validate)

    sum_df = nuc_df.join(syn_sum_df.set_index('Root ID'), on='Root ID')
    return [sum_df, up_df, down_df, val_status]

In [155]:
# client = CAVEclient("flywire_fafb_production")
# mat_vers = max(client.materialize.get_versions())
# test_df = client.materialize.query_table(
#     'synapses_nt_v1',
#     filter_in_dict={
#         "pre_pt_root_id":[720575940628522967],
#         # "post_pt_root_id":[post_root_id],
#         },
#     materialization_version = mat_vers
# )

# test_df.head(1)

In [156]:
# DASH APP #

app = dash.Dash(__name__)


def _readonlyText(text_id, value, width, rows):
    """Build a disabled, non-resizable Textarea used as a static label/message box."""
    return dcc.Textarea(
        id=text_id,
        value=value,
        style={'width': width, 'resize': 'none'},
        rows=rows,
        disabled=True,
    )


def _partnerTable(table_id):
    """Build a scrollable partner table (180px tall, frozen header row, CSV export)."""
    return html.Div(dash_table.DataTable(
        id=table_id,
        export_format="csv",
        style_table={'height': '180px', 'overflowY': 'auto'},
        page_action='none',
        fixed_rows={'headers': True},
        style_cell={'width': 160},
    ))


# page layout, top to bottom: instructions, id/coord input, cleft threshold,
# validation checkbox, submit button, summary table, incoming and outgoing tables #
app.layout = html.Div([
    _readonlyText(
        'message_text',
        'Input root/nuc ID or coordinates and click "Submit" button.\n'
        'Only one entry at a time.',
        '500px',
        2,
    ),
    html.Div(dcc.Input(
        id='input_field',
        type='text',
        placeholder='Root/Nuc ID or Coordinates',
    )),
    html.Br(),

    _readonlyText('cleft_message_text', 'Cleft score threshold for synapses:', '400px', 1),
    html.Div(dcc.Input(id='cleft_thresh_field', type='number', value=50)),
    html.Br(),

    html.Div(dcc.Checklist(
        id='val_check',
        options=[{'label': 'Data Validation', 'value': True}],
        labelStyle={'display': 'block'},
    )),
    html.Button(
        'Submit',
        id='submit_button',
        n_clicks=0,
        style={'margin-top': '15px', 'margin-bottom': '15px'},
    ),
    html.Br(),

    # summary table sizes its columns to their text rather than the container #
    html.Div(dash_table.DataTable(
        id='summary_table',
        fill_width=False,
        export_format="csv",
    )),
    html.Br(),

    _partnerTable('incoming_table'),
    html.Br(),
    _partnerTable('outgoing_table'),
])

#defines callback that takes root ids and desired data selection on button click and generates table#
@app.callback(
    Output('summary_table','columns'),
    Output('summary_table', 'data'),
    Output('incoming_table','columns'),
    Output('incoming_table', 'data'),
    Output('outgoing_table','columns'),
    Output('outgoing_table', 'data'),
    Output('message_text','value'),
    Input('submit_button', 'n_clicks'),  #defines trigger as button press (change in the state of the 'n_clicks' aspect of 'submit_button')# 
    State('input_field', 'value'),
    State('cleft_thresh_field','value'),
    State('val_check','value'),
    prevent_initial_call=True,           #prevents function from being called on page load (prior to input)#
)
def update_output(n_clicks, input_list, cleft_thresh,val_choice):
    """Build the three output tables and a status message for one query.

    Args:
        n_clicks: click count of the submit button (only used as the trigger).
        input_list: raw text of the input field: one root/nuc ID, or x,y,z
            coordinates separated by commas.
        cleft_thresh: value of the cleft-threshold number field.
        val_choice: value of the 'val_check' Checklist, i.e. the list of
            checked option values (None before it is first touched).

    Returns:
        [summary columns, summary data, incoming columns, incoming data,
         outgoing columns, outgoing data, message text]
    """
    # placeholder outputs that clear all three tables (DataTable columns/data take lists) #
    empty_tables = [[], [], [], [], [], []]

    # rejects an empty submission instead of querying the string 'None' #
    if input_list is None or not str(input_list).strip():
        return empty_tables + ['Please enter a root/nuc ID or coordinates.']

    # splits the entry on commas and strips surrounding whitespace #
    input_list = [x.strip() for x in str(input_list).split(',')]

    # accepts one id, or three values whose first and last differ in length
    # (read as x,y,z); anything else is rejected as multiple entries #
    is_single_id = len(input_list) == 1
    is_coords = len(input_list) == 3 and len(input_list[0]) != len(input_list[2])
    if not (is_single_id or is_coords):
        return empty_tables + ['Please limit each query to one entry.']

    # a cleared number field submits None, which cannot be compared to cleft scores #
    if cleft_thresh is None:
        return empty_tables + ['Please enter a cleft score threshold.']

    # the Checklist reports the checked values as a list (e.g. [True]), never bare True #
    val_in = val_choice is not None and True in val_choice

    # sets dataframes by passing id/coords into dfBuilder; failures are shown to the user #
    try:
        sum_df, up_df, down_df, val_status = dfBuilder(input_list, cleft_thresh, val_in)
    except Exception as err:
        return empty_tables + ['Query failed: ' + repr(err)]

    # creates column lists based on dataframe columns #
    sum_column_list = [{"name": i, "id": i} for i in sum_df.columns]
    up_column_list = [{"name": i, "id": i} for i in up_df.columns]
    down_column_list = [{"name": i, "id": i} for i in down_df.columns]

    #returns list of column names, data values (as records), and validation message#
    return [
        sum_column_list,
        sum_df.to_dict('records'),
        up_column_list,
        up_df.to_dict('records'),
        down_column_list,
        down_df.to_dict('records'),
        val_status,
    ]
        
if __name__ == '__main__':
    # starts the development server; newer Dash versions replace
    # run_server (deprecated, removed in Dash 3) with run #
    if hasattr(app, 'run'):
        app.run()
    else:
        app.run_server()

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is run

 * Running on http://127.0.0.1:8050/ (Press CTRL+C to quit)
127.0.0.1 - - [19/Nov/2021 15:03:00] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [19/Nov/2021 15:03:00] "GET /_dash-layout HTTP/1.1" 200 -
127.0.0.1 - - [19/Nov/2021 15:03:00] "GET /_dash-dependencies HTTP/1.1" 200 -
127.0.0.1 - - [19/Nov/2021 15:03:00] "GET /_favicon.ico?v=2.0.0 HTTP/1.1" 200 -
127.0.0.1 - - [19/Nov/2021 15:03:00] "GET /_dash-component-suites/dash/dash_table/async-highlight.js HTTP/1.1" 200 -
127.0.0.1 - - [19/Nov/2021 15:03:00] "GET /_dash-component-suites/dash/dash_table/async-table.js HTTP/1.1" 200 -
127.0.0.1 - - [19/Nov/2021 15:03:05] "POST /_dash-update-component HTTP/1.1" 200 -
