# In [None]:
#IMPORTATIONS#

# core importations #
import caveclient
from caveclient import CAVEclient
import cloudvolume
from cloudvolume import CloudVolume
import pandas as pd
import numpy as np

# dash importations #
import dash
from dash import Dash, dcc, html, Input, Output, State, dash_table

# In [None]:
# FUNCTIONS #

# defines function for converting nanometer coordinates to 4x4x40 resolution #
def coordConvert(coords, resolution=(4, 4, 40)):
    """Convert a nanometer x, y, z position to 4x4x40 voxel coordinates, as strings.

    Parameters
    ----------
    coords : sequence of numbers
        Position in nanometers (list, tuple or NumPy array). Only the first
        ``len(resolution)`` items are used.
    resolution : sequence of numbers, optional
        Per-axis divisor; defaults to ``(4, 4, 40)``.

    Returns
    -------
    list of str
        ``[str(x / 4), str(y / 4), str(z / 40)]`` for the default resolution.
        ``coords`` itself is left untouched.
    """
    # build a new list instead of dividing in place: the caller passes table cells #
    # (e.g. integer NumPy arrays) that must not be mutated and cannot hold floats  #
    return [str(value / scale) for value, scale in zip(coords, resolution)]

# defines function to convert list of coordinates in [4,4,40] resolution to root id #
def coordsToRoot(coords):
    """Look up the root id under one point given in 4x4x40 resolution coordinates.

    Parameters
    ----------
    coords : sequence
        x, y, z as ints or numeric strings (e.g. ``"123"`` or ``"123.0"``);
        only the first three items are used for the lookup.

    Returns
    -------
    str
        Root id (from the chunkedgraph) of the supervoxel at that point.
    """
    # sets client #
    client = CAVEclient("flywire_fafb_production")

    # sets cloud volume #
    cv = cloudvolume.CloudVolume("graphene://https://prod.flywire-daf.com/segmentation/1.0/fly_v31", use_https=True)

    # determines resolution of volume #
    res = cv.resolution

    # parses via float so decimal strings such as "123.0" are accepted, not only "123" #
    coords = [int(float(i)) for i in coords]

    # rescales the 4x4x40 coordinates to the volume's own resolution #
    cv_xyz = [int(coords[0] / (res[0] / 4)),
              int(coords[1] / (res[1] / 4)),
              int(coords[2] / (res[2] / 40))]

    # downloads a single voxel around the point #
    cutout = cv.download_point(cv_xyz, size=1)

    # takes the lone voxel value explicitly; int() on an ndarray with ndim > 0 is deprecated in NumPy #
    point = int(np.asarray(cutout).ravel()[0])

    # looks up root id for that supervoxel using chunkedgraph #
    return str(client.chunkedgraph.get_root_id(supervoxel_id=point))

# defines function for querying nucleus table using list of root or nuc ids #
# and query type ('nuc' or 'root') #
def getNucs(ids, query_type):
    """Query the ``nuclei_v1`` table by nucleus id or by root id.

    Parameters
    ----------
    ids : list
        Nucleus ids (``query_type='nuc'``) or root ids (``query_type='root'``).
    query_type : {'nuc', 'root'}
        Selects which table column the ids are matched against.

    Returns
    -------
    pandas.DataFrame
        Columns 'Root ID', 'Nucleus ID', 'Nucleus Coordinates' (coordinates
        converted to 4x4x40 resolution by ``coordConvert``); all values are str.

    Raises
    ------
    ValueError
        If ``query_type`` is not 'nuc' or 'root' (previously this surfaced
        as an UnboundLocalError).
    """
    # maps query type to the nuclei_v1 column it filters on #
    filter_columns = {'nuc': 'id', 'root': 'pt_root_id'}
    if query_type not in filter_columns:
        raise ValueError("query_type must be 'nuc' or 'root', got %r" % (query_type,))

    # sets client #
    client = CAVEclient("flywire_fafb_production")

    # gets current materialization version #
    mat_vers = max(client.materialize.get_versions())

    # pulls nucleus table rows matching the ids in the selected column #
    nuc_df = client.materialize.query_table('nuclei_v1',
                                            filter_in_dict={filter_columns[query_type]: ids},
                                            materialization_version=mat_vers)

    # converts nucleus coordinates from nm to 4x4x40 resolution without touching nuc_df #
    nuc_coords = [coordConvert(pos) for pos in nuc_df['pt_position']]

    # creates output dataframe from the same rows so ids and coords stay aligned #
    out_df = pd.DataFrame({'Root ID': list(nuc_df['pt_root_id'].astype('str')),
                           'Nucleus ID': list(nuc_df['id']),
                           'Nucleus Coordinates': nuc_coords})

    return out_df.astype(str)

# defines function to get presynaptic table entry using root id #
def getSyns(root_ids, cleft_thresh=0):
    """Summarise synapse counts, partner counts and neurotransmitter scores per root id.

    For each root id, ``synapses_nt_v1`` is queried once where the root is
    presynaptic (outgoing) and once where it is postsynaptic (incoming).
    Synapses with ``cleft_score < cleft_thresh`` are dropped before any
    counting. Neurotransmitter averages use only rows whose ``valid_nt``
    column equals ``'t'``.

    Parameters
    ----------
    root_ids : iterable
        Root ids to summarise (passed as-is to the table filter).
    cleft_thresh : number, optional
        Minimum cleft score a synapse must have to be kept (default 0).

    Returns
    -------
    pandas.DataFrame
        One row per root id, every value as str. Columns: 'Root ID',
        'Incoming Synapses', 'Outgoing Synapses', 'Upstream Parters',
        'Downstream Partners', then 'Post <NT> Avg' and 'Pre <NT> Avg' for
        Gaba, Ach, Glut, Oct, Ser, Da (means rounded to 3 places).
    """
    # sets client #
    client = CAVEclient("flywire_fafb_production")

    # pins the synapse queries to the newest materialization, as getNucs does #
    mat_vers = max(client.materialize.get_versions())

    # (table column, label used in the output column name) for each neurotransmitter score #
    nt_scores = (('gaba', 'Gaba'), ('ach', 'Ach'), ('glut', 'Glut'),
                 ('oct', 'Oct'), ('ser', 'Ser'), ('da', 'Da'))

    def nt_means(syn_df, prefix):
        """Return {'<prefix> <NT> Avg': mean score rounded to 3 places} for every NT column."""
        return {f'{prefix} {label} Avg': round(syn_df[col].mean(), 3) for col, label in nt_scores}

    # column names are kept byte-for-byte (including 'Parters') so exported CSVs stay compatible #
    columns = (['Root ID', 'Incoming Synapses', 'Outgoing Synapses',
                'Upstream Parters', 'Downstream Partners']
               + [f'Post {label} Avg' for _, label in nt_scores]
               + [f'Pre {label} Avg' for _, label in nt_scores])

    rows = []
    for root_id in root_ids:

        # gets outgoing (root is pre) and incoming (root is post) synapses #
        pre_syn_df = client.materialize.query_table('synapses_nt_v1',
                                                    filter_in_dict={"pre_pt_root_id": [root_id]},
                                                    materialization_version=mat_vers)
        post_syn_df = client.materialize.query_table('synapses_nt_v1',
                                                     filter_in_dict={"post_pt_root_id": [root_id]},
                                                     materialization_version=mat_vers)

        # drops likely false positives: synapses whose cleft score is below the threshold #
        pre_syn_df = pre_syn_df.loc[pre_syn_df['cleft_score'] >= cleft_thresh]
        post_syn_df = post_syn_df.loc[post_syn_df['cleft_score'] >= cleft_thresh]

        # counts synapses and distinct partners (dropna=False matches len(unique())) #
        row = {'Root ID': root_id,
               'Incoming Synapses': len(post_syn_df),
               'Outgoing Synapses': len(pre_syn_df),
               'Upstream Parters': post_syn_df['pre_pt_root_id'].nunique(dropna=False),
               'Downstream Partners': pre_syn_df['post_pt_root_id'].nunique(dropna=False)}

        # neurotransmitter averages only use rows flagged valid_nt == 't' #
        row.update(nt_means(post_syn_df.loc[post_syn_df['valid_nt'] == 't'], 'Post'))
        row.update(nt_means(pre_syn_df.loc[pre_syn_df['valid_nt'] == 't'], 'Pre'))
        rows.append(row)

    # builds the frame once; DataFrame.append was removed in pandas 2.0 (and was quadratic) #
    return pd.DataFrame(rows, columns=columns).astype(str)

# defines function to query changelog for edit data #
def getEdits(root_ids):
    """Return split, merge and total edit counts from the chunkedgraph change log.

    Parameters
    ----------
    root_ids : iterable
        Root ids passed one at a time to ``get_change_log``.

    Returns
    -------
    pandas.DataFrame
        Columns 'Root ID', 'Splits', 'Merges', 'Total Edits', all str. Roots
        whose change log cannot be fetched or read get 'n/a' in all three
        count columns. The columns exist even when ``root_ids`` is empty.
    """
    # sets client #
    client = CAVEclient("flywire_fafb_production")

    rows = []
    for root_id in root_ids:
        try:
            # gets changelog dictionary using root id #
            change_dict = client.chunkedgraph.get_change_log(root_id)
            splits = change_dict['n_splits']
            merges = change_dict['n_mergers']
            edits = splits + merges
        except Exception:
            # best effort: a failed lookup marks this root 'n/a' instead of aborting the table; #
            # Exception (not a bare except) keeps KeyboardInterrupt/SystemExit working         #
            splits = merges = edits = 'n/a'

        rows.append({'Root ID': root_id, 'Splits': splits, 'Merges': merges,
                     'Total Edits': edits})

    # builds the frame once; DataFrame.append was removed in pandas 2.0 #
    return pd.DataFrame(rows, columns=['Root ID', 'Splits', 'Merges', 'Total Edits']).astype(str)

# defines function to query proofreading table using list of root ids #
def getProof(root_list):
    """Look up proofreading status and proofreader id for each root id.

    Parameters
    ----------
    root_list : iterable
        Root ids (ints or numeric strings).

    Returns
    -------
    pandas.DataFrame
        Columns 'Root ID' (as given), 'Proofread' (True when the table's
        'proofread' value is 't', else False) and 'Proofreader' (the table's
        'user_id', or 'n/a' when the root is not in the table).
    """
    roots = list(root_list)

    # sets client #
    client = CAVEclient('flywire_fafb_production')

    # queries proofreading table using root ids #
    proof_df = client.materialize.query_table('proofreading_status_public_v1',
                                              filter_in_dict={'pt_root_id': roots})

    # indexes the result once, keeping the first row per root so each lookup is a single value #
    if 'pt_root_id' in proof_df.columns:
        lookup = proof_df.drop_duplicates('pt_root_id').set_index('pt_root_id')
    else:
        lookup = pd.DataFrame()

    def cell(key, column):
        """Return lookup[key, column], or None if the key or column is absent."""
        if key is None or column not in lookup.columns or key not in lookup.index:
            return None
        return lookup.at[key, column]

    proofed_list = []
    proofreader_list = []
    for root in roots:
        # non-numeric ids simply have no table entry #
        try:
            key = int(root)
        except (TypeError, ValueError):
            key = None

        # None == 't' is False, so missing roots read as not proofread #
        proofed_list.append(cell(key, 'proofread') == 't')
        user = cell(key, 'user_id')
        proofreader_list.append('n/a' if user is None else user)

    # creates output df using lists #
    return pd.DataFrame({'Root ID': roots,
                         'Proofread': proofed_list,
                         'Proofreader': proofreader_list})

# defines function to build dataframe using root id list #
# and options (list of keywords based on checkbox input) #
def dfBuilder(id_list, options, cleft_thresh):
    """Build the combined output table for a list of ids or coordinates.

    The input type is detected from the entries:

    * every entry 18 characters long -> root ids
    * every entry 7 characters long  -> nucleus ids
    * otherwise, a count divisible by 3 -> consecutive x, y, z coordinate
      triples in 4x4x40 resolution, one root lookup per triple

    Parameters
    ----------
    id_list : list
        Ids or coordinate components (strings from the input field).
    options : iterable of str
        Checkbox values; 'proof', 'edits' and 'syns' add the matching columns.
    cleft_thresh : number
        Cleft score threshold forwarded to getSyns.

    Returns
    -------
    pandas.DataFrame
        'Root ID', 'Root Current', the nucleus columns, then any optional
        columns; every value converted to str.

    Raises
    ------
    ValueError
        If ``id_list`` is empty or matches none of the input types above
        (previously an UnboundLocalError).
    """
    if not id_list:
        raise ValueError('No ids or coordinates were provided.')
    options = options or ()

    if all(len(str(i)) == 18 for i in id_list):
        # root ids: one row per root, joined to any nuclei they have #
        root_list = [str(i) for i in id_list]
        out_df = _rootFrameWithNucs(root_list)

    elif all(len(str(i)) == 7 for i in id_list):
        # nucleus ids: the nucleus table supplies the rows and their roots #
        out_df = getNucs(id_list, 'nuc')
        root_list = list(out_df['Root ID'])

    elif len(id_list) % 3 == 0:
        # coordinates: one root lookup per x, y, z triple (previously only the first triple was used) #
        root_list = [coordsToRoot(id_list[k:k + 3]) for k in range(0, len(id_list), 3)]
        out_df = _rootFrameWithNucs(root_list)

    else:
        raise ValueError('Input must be 18-digit root ids, 7-digit nucleus ids, '
                         'or x, y, z coordinates.')

    # creates the client here; it was previously referenced without being defined (NameError) #
    client = CAVEclient("flywire_fafb_production")
    if root_list:
        flags = client.chunkedgraph.is_latest_roots([int(r) for r in root_list])
        current = dict(zip(root_list, flags))
    else:
        current = {}
    # maps per row so roots repeated by a one-to-many nucleus join still line up #
    out_df.insert(1, "Root Current", out_df['Root ID'].map(current))

    # adds proofreading data if checkbox is marked #
    if 'proof' in options:
        proof_df = getProof(root_list)
        out_df = out_df.join(proof_df.set_index('Root ID'), on='Root ID')

    # adds edit data if edits checkbox is marked #
    if 'edits' in options:
        edit_df = getEdits(root_list)
        out_df = out_df.join(edit_df.set_index('Root ID'), on='Root ID')

    # adds synapse data if synapse checkbox is marked #
    if 'syns' in options:
        syn_df = getSyns(root_list, cleft_thresh)
        out_df = out_df.join(syn_df.set_index('Root ID'), on='Root ID')

    return out_df.astype(str)


def _rootFrameWithNucs(root_list):
    """Return one 'Root ID' row per root, left-joined to its nucleus-table columns."""
    out_df = pd.DataFrame({'Root ID': root_list})
    nuc_df = getNucs(root_list, 'root')
    return out_df.join(nuc_df.set_index('Root ID'), on='Root ID')

# In [None]:
# DASH APP #

app = dash.Dash(__name__)


def _buildLayout():
    """Assemble the page: status box, id field, data checkboxes,
    cleft threshold field, submit button, and the exportable results table."""
    # read-only box used to relay instructions and messages #
    status_box = dcc.Textarea(
        id='message_text',
        value=('Input coordinates, select output parameters, and click "Submit" button.\n'
               'ID queries are limited to 20 entries, coordinate lookups must be done one at a time.\n'
               'Lookup takes ~2-3 seconds per entry.'),
        style={'width': '650px', 'resize': 'none'},
        rows=3,
        disabled=True,
    )

    # free-text field for ids / coordinates #
    id_field = html.Div(dcc.Input(id='input_field', type='text', placeholder='ID Number'))

    # checkboxes choosing which optional data sets to fetch #
    choices = [
        ('Proofreading Status and Proofreader ID', 'proof'),
        ('Splits, Mergers, and Total Edits', 'edits'),
        ('Synapse Counts and Neurotransmitters', 'syns'),
    ]
    data_checklist = html.Div(dcc.Checklist(
        id='checkboxes',
        options=[{'label': text, 'value': key} for text, key in choices],
        labelStyle={'display': 'block'},
        value=['syns', 'edits', 'proof'],
    ))

    # label and numeric field for the cleft score threshold #
    cleft_label = dcc.Textarea(
        id='cleft_message_text',
        value='Cleft score threshold for synapses (default is 50)',
        style={'width': '400px', 'resize': 'none'},
        rows=1,
        disabled=True,
    )
    cleft_field = html.Div(dcc.Input(id='cleft_thresh_field', type='number', value=50))

    submit = html.Button('Submit', id='submit_button', n_clicks=0)

    # results table; fill_width=False sizes columns to their text #
    results = html.Div(dash_table.DataTable(id='table', fill_width=False, export_format="csv"))

    children = [
        status_box,
        id_field,
        html.Br(),
        data_checklist,
        cleft_label,
        cleft_field,
        html.Br(),
        submit,
        html.Br(),
        results,
    ]
    return html.Div(children)


app.layout = _buildLayout()

# defines callback that takes root ids and desired data selection on button click and generates table #
@app.callback(
    Output('table', 'columns'),            # first output: column definitions of 'table' #
    Output('table', 'data'),               # second output: row records of 'table' #
    Output('message_text', 'value'),       # third output: text of the message box #
    Input('submit_button', 'n_clicks'),    # defines trigger as button press #
    State('input_field', 'value'),         # first input state: value of 'input_field' #
    State('checkboxes', 'value'),          # second input state: value of 'checkboxes' #
    State('cleft_thresh_field', 'value'),  # third input state: value of 'cleft_thresh_field' #
    prevent_initial_call=True,             # prevents function from executing prior to button press #
)
def update_output(n_clicks, ids, checked, cleft_thresh):
    """Build the table for the submitted input and return [columns, data, message].

    Empty input, more than 20 entries, or a failed lookup return empty
    ``columns``/``data`` lists (valid DataTable props; ``0`` is not), which
    clears the table, and the problem is reported in the message box.
    """
    default_message = 'Input coordinates, select output parameters, and click "Submit" button.\n'\
        'ID queries are limited to 20 entries, coordinate lookups must be done one at a time.\n'\
        'Lookup takes ~2-3 seconds per entry.'

    # splits on commas, strips whitespace, drops blanks (e.g. a trailing comma); None -> empty #
    id_list = [x.strip() for x in str(ids or '').split(',') if x.strip()]

    if not id_list:
        return [[], [], 'Please enter at least one ID or coordinate.']

    if len(id_list) > 20:
        return [[], [], 'Please limit each query to a maximum of 20 items.']

    # a cleared number field arrives as None; fall back to the advertised default of 50 #
    if cleft_thresh is None:
        cleft_thresh = 50

    try:
        df = dfBuilder(id_list, checked or [], cleft_thresh)
    except Exception as err:  # callback boundary: report the failure instead of crashing the request #
        return [[], [], f'Lookup failed ({type(err).__name__}): {err}']

    # converts nucleus coords to str to avoid issues when passing to Dash table #
    if 'Nucleus Coordinates' in df.columns:
        df['Nucleus Coordinates'] = df['Nucleus Coordinates'].astype('str')

    # column definitions and row records for the Dash table #
    column_list = [{"name": col, "id": col} for col in df.columns]
    return [column_list, df.to_dict('records'), default_message]
        

if __name__ == '__main__':
    # Dash.run supersedes Dash.run_server (deprecated, removed in Dash 3); #
    # fall back to run_server on older Dash releases that lack run        #
    run_app = getattr(app, 'run', None) or app.run_server
    run_app()