In [None]:
#IMPORTATIONS#

# core importations
import caveclient
from caveclient import CAVEclient
import cloudvolume
from cloudvolume import CloudVolume
import pandas as pd
import numpy as np

# dash importations
import dash
from dash import Dash, dcc, html, Input, Output, State, dash_table

In [None]:
%%time
#FUNCTIONS#

# defines function to convert nm coordinates into FlyWire-usable
def coordConvert(coords):
    """Convert nanometre coordinates to FlyWire [4, 4, 40] voxel coordinates.

    Args:
        coords: indexable holding three numbers (x, y, z).

    Returns:
        A new list ``[x / 4, y / 4, z / 40]``. The sequence passed in is
        not modified.
    """
    # builds a fresh list rather than dividing in place, so the caller's data is left intact #
    return [coords[0] / 4, coords[1] / 4, coords[2] / 40]

# defines function to convert list of coordinates in [4,4,40] resolution to root id #
def coordsToRoot(coords):
    """Return the root id found at a coordinate given in [4, 4, 40] resolution.

    Args:
        coords: indexable holding the x, y and z values of the point.

    Returns:
        Whatever the chunkedgraph ``get_root_id`` call returns for the
        supervoxel downloaded at that point.
    """
    # connects to the datastack first, then opens the segmentation volume #
    cave = CAVEclient("flywire_fafb_production")
    volume = cloudvolume.CloudVolume(
        "graphene://https://prod.flywire-daf.com/segmentation/1.0/fly_v31",
        use_https=True,
    )

    # rescales each axis from the [4,4,40] input resolution to the volume's own resolution #
    input_res = (4, 4, 40)
    volume_res = volume.resolution
    volume_xyz = [
        int(coords[axis] / (volume_res[axis] / input_res[axis]))
        for axis in range(3)
    ]

    # downloads the single supervoxel id sitting at that point #
    supervoxel = int(volume.download_point(volume_xyz, size=1))

    # asks the chunkedgraph for the root id of that supervoxel #
    return cave.chunkedgraph.get_root_id(supervoxel_id=supervoxel)

# defines function to get [list of nuc ids, list of root ids] from list of nucleus ids
def nucToRoot(nucleus_ids):
    """Query the 'nuclei_v1' table for the given nucleus ids.

    Args:
        nucleus_ids: list of nucleus ids used as an ``id`` filter.

    Returns:
        ``[[nucleus ids], [root ids]]`` read from the same query result, so
        the two lists stay aligned even if the query reorders the rows.
    """
    cave = CAVEclient("flywire_fafb_production")

    # picks the highest materialization version reported by the server #
    latest_version = max(cave.materialize.get_versions())

    # fetches the nucleus rows whose 'id' is in the requested list #
    nuclei = cave.materialize.query_table(
        'nuclei_v1',
        materialization_version=latest_version,
        filter_in_dict={"id": nucleus_ids},
    )

    # pulls both id columns out of the one result frame #
    return [nuclei[column].tolist() for column in ('id', 'pt_root_id')]

# defines function to get synapse info using root ID#
def getSyn(root_id):
    """Summarise the synapses of one neuron as a single-row dataframe.

    Two queries are made against 'synapses_nt_v1': one for synapses where
    ``root_id`` is the presynaptic side (outputs) and one where it is the
    postsynaptic side (inputs). Both use the same materialization version.

    Args:
        root_id: root id of the neuron to look up.

    Returns:
        A one-row ``pandas.DataFrame`` of strings with the columns
        'Root ID', 'Post Count', 'Post <NT> Avg' (x6), 'Upstream Partners',
        'Pre Count', 'Pre <NT> Avg' (x6), 'Downstream Partners', in that
        order. Averages are rounded to 3 decimals before being stringified.
    """
    #neurotransmitter columns, in the order they appear in the output#
    transmitters = ('gaba', 'ach', 'glut', 'oct', 'ser', 'da')

    #sets client#
    client = CAVEclient("flywire_fafb_production")

    #gets current materialization version#
    mat_vers = max(client.materialize.get_versions())

    #makes df of presynapses (outputs); version is pinned so both queries read the same snapshot#
    pre_syn_df = client.materialize.query_table('synapses_nt_v1',
                                                filter_in_dict={"pre_pt_root_id": [root_id]},
                                                materialization_version=mat_vers)

    #makes df of postsynapses (inputs)#
    post_syn_df = client.materialize.query_table('synapses_nt_v1',
                                                 filter_in_dict={"post_pt_root_id": [root_id]},
                                                 materialization_version=mat_vers)

    #(label prefix, synapse df, column holding the partner's root id, partner column label)#
    sides = (
        ('Post', post_syn_df, 'pre_pt_root_id', 'Upstream Partners'),
        ('Pre', pre_syn_df, 'post_pt_root_id', 'Downstream Partners'),
    )

    #builds the row in output-column order: count, six averages, partner count per side#
    row = {'Root ID': root_id}
    for side, syn_df, partner_col, partner_label in sides:
        row[side + ' Count'] = len(syn_df)
        for nt in transmitters:
            row[side + ' ' + nt.capitalize() + ' Avg'] = round(syn_df[nt].mean(), 3)
        #counts distinct partner root ids#
        row[partner_label] = len(syn_df[partner_col].unique())

    #converts all data to strings#
    return pd.DataFrame([row]).astype(str)

#defines function to build dataframe using id list or coords and query type#
def dfBuilder(id_list,input_type):
    """Build the synapse summary table for a list of ids or one coordinate.

    Args:
        id_list: root ids ('root'), nucleus ids ('nuc'), or the numeric
            x, y, z values of a single point ('coord').
        input_type: one of 'root', 'nuc' or 'coord'.

    Returns:
        A ``pandas.DataFrame`` with one row per resolved root id and the
        columns produced by ``getSyn``; for 'nuc' a leading 'Nucleus ID'
        column is added. An unrecognised ``input_type`` prints 'ERROR' and
        returns an empty frame that still has the column headers.
    """
    column_names = ['Root ID','Post Count', 'Post Gaba Avg', 'Post Ach Avg', 'Post Glut Avg',
                    'Post Oct Avg', 'Post Ser Avg', 'Post Da Avg', 'Upstream Partners',
                    'Pre Count', 'Pre Gaba Avg', 'Pre Ach Avg', 'Pre Glut Avg',
                    'Pre Oct Avg', 'Pre Ser Avg', 'Pre Da Avg', 'Downstream Partners']

    #stacks one getSyn row per root id; ids go in as strings to avoid rounding#
    def _stack_rows(root_ids):
        rows = [getSyn(str(root)) for root in root_ids]
        if not rows:
            #pd.concat rejects an empty list, so hand back the headers only#
            return pd.DataFrame(columns=column_names)
        #DataFrame.append no longer exists in pandas 2.x; ignore_index gives rows 0..n-1#
        return pd.concat(rows, ignore_index=True)

    #process if user chooses root id input#
    if input_type == 'root':
        return _stack_rows(id_list)

    #process if user chooses nucleus id input#
    if input_type == 'nuc':

        #looks up root ids; both lists come back aligned with each other#
        nuc_ids, root_ids = nucToRoot([str(i) for i in id_list])

        #single-column frame of nucleus ids, as strings for the Dash table#
        nuc_df = pd.DataFrame({'Nucleus ID': nuc_ids}).astype(str)

        #index-based join lines up because both frames are numbered 0..n-1#
        return nuc_df.join(_stack_rows(root_ids))

    #process if user chooses coordinate input#
    if input_type == 'coord':

        #coordinates must stay numeric here: coordsToRoot divides them#
        root_id = coordsToRoot(id_list)

        #summarises the neuron found at the point (list form allows multiple coords in future)#
        return _stack_rows([root_id])

    #reserved for error handling section, not implemented yet#
    print('ERROR')
    return pd.DataFrame(columns=column_names)


In [None]:
# DASH APP #

app = Dash(__name__)

# page layout, top to bottom: message box, input field, query-type dropdown, submit button, output table #
app.layout = html.Div(children=[

    # read-only text area showing instructions and feedback #
    dcc.Textarea(
        id='message_text',
        value=(
            'Choose lookup method from dropdown, input coordinates, select output parameters, and click '
            '"Submit" button.\nID queries are limited to 20 entries, coordinate lookups must be done one at a '
            'time.\nLookup takes ~2-3 seconds per entry.'
        ),
        rows=3,
        disabled=True,
        style={'width': '800px', 'resize': 'none'},
    ),

    # free-text field for the ids / coordinates #
    html.Div(children=dcc.Input(id='input_field', type='text', placeholder='ID Number')),

    html.Br(),

    # dropdown selecting how the input is interpreted; defaults to root ids #
    dcc.Dropdown(
        id='query_type',
        value='root',
        options=[
            {'label': 'Root ID', 'value': 'root'},
            {'label': 'Nucleus ID', 'value': 'nuc'},
            {'label': 'Coordinates (batch coordinate input not currently supported)', 'value': 'coord'},
        ],
        style={'max-width': '500px'},
    ),

    html.Br(),

    # button whose click count triggers the lookup #
    html.Button('Submit', id='submit_button', n_clicks=0),

    html.Br(),

    # results table; columns are sized to their text and the table offers csv export #
    html.Div(children=dash_table.DataTable(
        id='table',
        export_format="csv",
        fill_width=False,
    )),
])

#defines callback that takes the input text and query type on button click and generates table#
@app.callback(
    Output('table','columns'),           #first output: the 'columns' aspect of 'table'#
    Output('table', 'data'),             #second output: the 'data' aspect of 'table'#
    Output('message_text','value'),      #third output: the 'value' aspect of 'message_text'#
    Input('submit_button', 'n_clicks'),  #trigger: button press (change in the 'n_clicks' aspect of 'submit_button')#
    State('query_type', 'value'),        #first input state: value of 'query_type'#
    State('input_field', 'value'),       #second input state: value of 'input_field'#
    prevent_initial_call=True,           #prevents function from being called on page load (prior to input)#
)
def update_output(n_clicks, query_method, ids):
    """Run the lookup for the submitted text and fill the output table.

    Args:
        n_clicks: click count of the submit button (only used as trigger).
        query_method: 'root', 'nuc' or 'coord', from the dropdown.
        ids: raw comma-separated text from the input field (None if empty).

    Returns:
        ``[columns, data, message]`` for the table and the message box.
        When the input is rejected, columns and data are empty lists and
        the message explains why.
    """
    #parses the comma-separated text into integers; blank or non-numeric input is rejected#
    try:
        id_list = [int(x.strip()) for x in str(ids).split(",")]
    except ValueError:
        return [[], [], 'Please enter whole numbers separated by commas.']

    #a coordinate lookup needs exactly one x, y, z triple#
    if query_method == 'coord' and len(id_list) != 3:
        return [[], [], 'Coordinate lookups require exactly three comma-separated values (x, y, z).']

    #returns error message if 20-item threshold is exceeded#
    if len(id_list) > 20:
        return [[], [], 'Please limit each query to a maximum of 20 id numbers.']

    #passes id list and query method into dfBuilder function to make dataframe#
    df = dfBuilder(id_list,query_method)

    #creates column list based on dataframe columns#
    column_list = [{"name": i, "id": i} for i in df.columns]

    #makes list of row dictionaries from dataframe#
    df_dict = df.to_dict('records')

    #keeps message output the same#
    message_output = 'Choose lookup method from dropdown, input coordinates, select output parameters, '\
        'and click "Submit" button.\nID queries are limited to 20 entries, coordinate lookups must be '\
        'done one at a time.\nLookup takes ~2-3 seconds per entry.'

    #returns list of column names, data values, and message text#
    return [column_list,df_dict,message_output]
        

# starts the Dash server only when this file is executed as a script, not when it is imported #
# NOTE(review): newer Dash releases favour app.run() over app.run_server() - confirm against the installed version #
if __name__ == '__main__':
    app.run_server()