In [None]:
from caveclient import CAVEclient
import pandas as pd
import numpy as np 
import os
import pcg_skel
import datetime
import matplotlib.pyplot as plt
import seaborn as sns
from meshparty.meshwork import algorithms
from meshparty import meshwork
from meshparty import trimesh_io, trimesh_vtk, skeletonize, mesh_filters
from nglui.statebuilder import *
import cloudvolume as cv
import connectome_create
import utils
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Wall-clock time of this session (kept for ad-hoc queries).
now = datetime.datetime.now()

dataset = 'fanc_production_mar2021'
client = CAVEclient(dataset)

# Query the soma table at the same timestamp as the connectivity data and the MN table
# (connectome_create.timestamp). Querying at `now` lets pt_root_ids drift away from the
# ones in pre_to_mn_df, so the soma merge below would silently produce NaN positions.
soma_table = client.materialize.query_table('soma_jan2022', timestamp=connectome_create.timestamp)


In [None]:
# Which saved variant of the presynaptic -> MN connectivity table to load
# (the meaning of the suffix is defined in connectome_create).
PRE_TO_MN_EXT = 'matched_typed_with_nt'
pre_to_mn_df = connectome_create.load_pre_to_mn_df(ext=PRE_TO_MN_EXT)

pre_to_mn_df.shape

In [None]:
# Inspect the loaded connectivity table; its columns are the MNs used below.
pre_to_mn_df

In [None]:
# One row per MN column of pre_to_mn_df: the column MultiIndex as a frame
# (its index is the MultiIndex itself).
mn_cols_frame = pre_to_mn_df.columns.to_frame()
left_index = mn_cols_frame.index

# Add the soma position from the soma table. merge() discards the index, so the
# original MultiIndex is restored with set_axis; this fails if a pt_root_id matches
# more than one soma row (the row count would then change).
left_mn_df = (
    mn_cols_frame
    .rename(columns={'segID': 'pt_root_id'})
    .merge(
        soma_table[['pt_root_id', 'pt_position']],
        how='left',
        on='pt_root_id',
        suffixes=['_mn', '_soma'],
    )
    .set_axis(left_index, axis=0)
)
left_mn_df.head()


In [None]:
# Save the MN MultiIndex (one row per MN) as CSV to ./dfs_saved/mn_index.
# to_csv does not create missing directories, so ensure the target exists first.
os.makedirs('./dfs_saved', exist_ok=True)
left_mn_df.index.to_frame().to_csv('./dfs_saved/mn_index')

In [None]:
mn_ids = left_mn_df['pt_root_id'].tolist()
muscle_tuple_dict = utils.get_motor_pool_tuple_dict()

# Motor pools to label, in display order; this list also fixes the category order below.
pool_keys = [
    'thorax_swing',
    'thorax_stance',
    'trochanter_extension',
    'trochanter_flexion',
    'femur_reductor',
    'tibia_extensor',
    'main_tibia_flexor',
    # 'auxiliary_tibia_flexor_A',  # currently not labelled
    'auxiliary_tibia_flexor_B',
    'auxiliary_tibia_flexor_E',
    'ltm',
    'tarsus_depressor_med_venU',
    'tarsus_depressor_noid',
]

# Label the rows selected by each pool's index key(s); rows not selected by any pool stay NaN.
for key in pool_keys:
    left_mn_df.loc[muscle_tuple_dict[key], 'preferred_pool'] = key

# Categorical whose categories follow pool_keys, so sorting uses pool order, not alphabetical order.
left_mn_df['preferred_pool'] = (
    left_mn_df['preferred_pool']
    .astype('category')
    .cat.set_categories(pool_keys)
)

# Show the MNs grouped by pool (stable sort); left_mn_df itself keeps its original row order.
left_mn_df.sort_values(['preferred_pool'], kind='mergesort')

# Look at the right vs. left neurons

In [None]:
# Motor-neuron annotation table, materialized at the same timestamp as the connectivity data.
t1_mns_df = client.materialize.query_table(
    'motor_neuron_table_v7',
    timestamp=connectome_create.timestamp,
)

In [None]:
# Key the MN table by root ID so left MNs can be joined onto it by pt_root_id.
t1_mns_df = t1_mns_df.set_index(keys='pt_root_id')

## Find the right-hand pair of each left MN
1. First, join point positions from the MN table onto the left MNs (`left_mn_df`).
2. Take the ordered columns of `left_mn_df`.
3. Join the left `pt_position` on `pt_root_id`.
4. Join the right `pt_position` on the left `pt_position`.
5. Join the right `pt_root_id` from `t1_mns_df`.

In [None]:
# Left MNs in column order of the connectivity table, with segID renamed to pt_root_id
# and a plain RangeIndex.
left_index_df = (
    left_mn_df.index.to_frame()
    .rename(columns={'segID': 'pt_root_id'})
    .reset_index(drop=True)
)

# Pull every MN-table column (incl. pt_position) for each left MN; clashing names get '_y'.
left_with_ptpos = left_index_df.join(
    t1_mns_df, on='pt_root_id', how='left', lsuffix='', rsuffix='_y'
)

In [None]:
# Convert each pt_position to a tuple so it is hashable and can serve as a join key.
# If a left pt_root_id was not found in the MN table (e.g. the root ID changed between
# the connectivity-matrix timestamp and this run), its pt_position is NaN. tuple(NaN)
# would fail with a cryptic TypeError, so report the offending IDs explicitly instead.
missing_pos = left_with_ptpos['pt_position'].isna()
if missing_pos.any():
    raise ValueError(
        'No pt_position found for pt_root_id(s) {}; the MN table and the connectivity '
        'matrix may be at different timestamps.'.format(
            left_with_ptpos.loc[missing_pos, 'pt_root_id'].tolist()
        )
    )
left_with_ptpos['pt_position'] = left_with_ptpos['pt_position'].map(tuple)

# With pt_position stored as a tuple (hashable), it can be used as a join key.

# If the above line failed, could try this to see where the mismatch is.
# for r in range(left_with_ptpos.shape[0]):
#     left_with_ptpos['pt_position'].iloc[r] = tuple(left_with_ptpos['pt_position'].iloc[r])
# left_with_ptpos['pt_position'].iloc[r]

In [None]:
# Make positions hashable, then re-key the MN table by position (keeping the column),
# restoring pt_root_id as an ordinary column, so right-hand positions can be looked up below.
t1_mns_df['pt_position'] = t1_mns_df['pt_position'].map(tuple)
t1_mns_df = (
    t1_mns_df
    .reset_index(drop=False)
    .set_index('pt_position', drop=False)
)
t1_mns_df

# Keep this pt_position-indexed table: it is used below to look up the right-hand pt_root_id once the right-hand positions are added.

In [None]:
# Inspect the MN table, now indexed by pt_position.
t1_mns_df

# Import the left/right MN pairing table (paired soma points)

In [None]:
paired_mns = pd.read_csv('./annotations_MN/Paired_MN_points_20221115.csv')
# The left position is in the 'Soma' column. The right position column has no header,
# so pandas names it 'Unnamed: 4'.
paired_mns = paired_mns.rename(columns={'Soma': 'pt_position_L', 'Unnamed: 4': 'pt_position_R'})

# Rows without a left position cannot be paired; keep them aside for inspection.
unpaired_row_loc = paired_mns.pt_position_L.isna()
unpaired_row = paired_mns.loc[unpaired_row_loc, :]

# Keep only rows that have BOTH positions. Dropping on every column would also discard
# complete pairs whose unrelated columns happen to be empty.
paired_mns = paired_mns.dropna(subset=['pt_position_L', 'pt_position_R'])
paired_mns

In [None]:
def make_num_list_from_str(x):
    """Parse the first bracketed list of integers in ``x`` into a list of ints.

    Accepts strings such as ``'[1 2 3]'``, ``'[ 10  20 30 ]'``, ``'[1, 2, 3]'``; tokens may be
    separated by any whitespace and/or commas. Text outside the first ``[...]`` is ignored.

    Parameters
    ----------
    x : str
        String containing a bracketed list of integers.

    Returns
    -------
    list of int
        The parsed integers (empty list for ``'[]'``).

    Raises
    ------
    IndexError
        If ``x`` contains no ``'['``.
    ValueError
        If a token inside the brackets is not an integer.
    """
    import re

    # Text between the first '[' and the following ']'.
    inner = x.split('[')[1].split(']')[0]
    # Split on runs of whitespace/commas; leading/trailing separators yield empty tokens.
    return [int(tok) for tok in re.split(r'[\s,]+', inner) if tok]


In [None]:
# Turn each "[x y z]" position string into an integer tuple (hashable, so usable as a join key).
for pos_col in ('pt_position_L', 'pt_position_R'):
    paired_mns[pos_col] = paired_mns[pos_col].apply(
        lambda text: tuple(make_num_list_from_str(text))
    )



In [None]:
# Key the pairing table by the left position so it can be joined on left_with_ptpos.pt_position.
paired_mns = paired_mns.set_index('pt_position_L', drop=True)
paired_mns

# Join paired table on pt_position in left_with_ptpos

In [None]:
# Attach the pairing-table columns (incl. pt_position_R) to each left MN via its pt_position.
paired_mns_on_ptpos = left_with_ptpos.join(
    other=paired_mns,
    on='pt_position',
    how='left',
)

In [None]:
# Inspect the left MNs with their paired right-hand positions.
paired_mns_on_ptpos

# Add R pt_root_id: Join paired table with T1 mns on pt_position_R

In [None]:
# Look up the right-hand MN by its position. Cast pt_root_id to nullable Int64 first:
# with a plain int64 column, a single unmatched pt_position_R makes pandas upcast the
# joined IDs to float64, and 64-bit root IDs above 2**53 would then be silently rounded.
t1_mns_for_join = t1_mns_df.astype({'pt_root_id': 'Int64'})
paired_df = paired_mns_on_ptpos.join(t1_mns_for_join, on='pt_position_R', how='left', rsuffix='_R')

index_cols = ['side', 'nerve', 'segment', 'function', 'muscle', 'rank', 'pt_root_id']
keep_cols = ['classification_system', 'classification_system_R', 'cell_type', 'cell_type_R',
             'pt_position', 'pt_position_R', 'pt_root_id_R']
stripped_paired_df = paired_df.set_index(index_cols)[keep_cols]

# Guard against repeated column labels; keep the first occurrence.
stripped_paired_df = stripped_paired_df.loc[:, ~stripped_paired_df.columns.duplicated()].copy()
# Also expose the left root ID (an index level) as a regular column.
stripped_paired_df['pt_root_id'] = stripped_paired_df.index.get_level_values('pt_root_id')
stripped_paired_df

# Query the pairs of pt_root_ids and compare the number of input synapses.

In [None]:
# Work on a copy: the next cell adds synapse-count columns to lr_df, and a plain
# assignment would alias stripped_paired_df and mutate it as well.
lr_df = stripped_paired_df.copy()
utils.save_df_as_pickle(lr_df, name='left_right_paired_ptrootids')
lr_df

In [None]:
# Count input synapses onto each left MN and its right-hand partner.
lr_df['input_syn_L'] = pd.NA
lr_df['input_syn_R'] = pd.NA

for idx, row in lr_df.iterrows():
    left_id = row['pt_root_id']
    right_id = row['pt_root_id_R']
    if pd.isna(right_id):
        raise ValueError(
            'Left MN {} has no matched right MN (pt_root_id_R is missing)'.format(left_id)
        )

    # One query per pair: all synapses whose postsynaptic root ID is either MN.
    mn_inputs_df = client.materialize.synapse_query(
        post_ids=[left_id, right_id], timestamp=connectome_create.timestamp
    )
    # Synapses per postsynaptic root ID. A neuron with no inputs is absent from the
    # result, so default to 0 instead of raising KeyError.
    syn_counts = mn_inputs_df.groupby('post_pt_root_id').size()
    lr_df.loc[idx, 'input_syn_L'] = int(syn_counts.get(left_id, 0))
    lr_df.loc[idx, 'input_syn_R'] = int(syn_counts.get(right_id, 0))

    print('Left_MN {} has {} syn, Right_MN {} has {} syn'.format(
        left_id, lr_df.loc[idx, 'input_syn_L'], right_id, lr_df.loc[idx, 'input_syn_R']))

# Every row is filled now; store plain integers (not object dtype) for regression/plotting.
lr_df['input_syn_L'] = lr_df['input_syn_L'].astype('int64')
lr_df['input_syn_R'] = lr_df['input_syn_R'].astype('int64')
    

In [None]:
# Inspect the paired MNs with their left/right input synapse counts.
lr_df

# Regress right-side input synapse counts on left-side counts

In [None]:
# Copy with explicit _L names so left and right columns are symmetric.
lrc_df = lr_df.copy().rename(
    columns={'pt_position': 'pt_position_L', 'pt_root_id': 'pt_root_id_L'}
)
lrc_df

In [None]:
# Inspect the left MNs with their soma positions and preferred_pool labels.
left_mn_df

In [None]:
# Lookup table: left MN pt_root_id -> preferred_pool.
preferred_pool_df = (
    left_mn_df.loc[:, ['preferred_pool', 'pt_root_id']]
    .set_index('pt_root_id', drop=True)
)
preferred_pool_df

In [None]:
# Re-attach preferred_pool by left root ID, keeping lrc_df's existing row order.
lrc_df = lrc_df.join(
    preferred_pool_df,
    how='left',
    on='pt_root_id_L',
    sort=False,
)
lrc_df

In [None]:
# Regressor (left counts) and response (right counts).
x, y = lrc_df['input_syn_L'], lrc_df['input_syn_R']

In [None]:
# Quick look: right vs. left input synapse counts, one point per MN pair.
fig = plt.figure(1, figsize=(10, 10))
ax1 = fig.add_subplot(1, 1, 1)

ax1.scatter(x, y, marker='o')
ax1.set_title('Right inputs vs. Left inputs')
ax1.set_xlabel('Left inputs')
ax1.set_ylabel('Right inputs')

# To export: fig.savefig('./figpanels/total_inputs.svg', format='svg')

In [None]:
from sklearn.linear_model import LinearRegression

# scikit-learn expects 2-D (n_samples, n_features) arrays.
y_ = y.to_numpy().reshape((-1, 1))
x_ = x.to_numpy().reshape((-1, 1))

# Ordinary least-squares fit of right-side counts on left-side counts.
reg = LinearRegression().fit(x_, y_)

# Endpoints of the fitted line: from 0 to the largest left-side count.
fit_x = np.array([0, x_.max()]).reshape((-1, 1))
fit = reg.predict(fit_x)

# One hex colour per preferred_pool value.
# NOTE(review): both tarsus_depressor pools share '#CECECE' and are indistinguishable
# in the legend; confirm that is intended.
cat_pal = {
    'thorax_swing': '#A502AA',
    'thorax_stance': '#00A2B4',
    'trochanter_extension': '#D5CB6C',
    'trochanter_flexion': '#3F42A2',
    'femur_reductor': '#FF0000',
    'tibia_extensor': '#CC8544',
    'main_tibia_flexor': '#2E3191',
    'auxiliary_tibia_flexor_B': '#2DB515',
    'auxiliary_tibia_flexor_E': '#156005',
    'ltm': '#FFF100',
    'tarsus_depressor_med_venU': '#CECECE',
    'tarsus_depressor_noid': '#CECECE',
}
fig, ax = plt.subplots(1, 1, figsize=(8, 8))

sns.scatterplot(data=lrc_df, x="input_syn_L", y="input_syn_R", hue="preferred_pool",
                palette=cat_pal, ax=ax)
ax.plot(fit_x, fit)

# savefig does not create missing directories.
os.makedirs('./figpanels', exist_ok=True)
fig.savefig('./figpanels/Left_vs_Right_Synapses.eps', format='eps')

In [None]:
# Coefficient of determination (R^2) of the fit over all MN pairs.
reg.score(x_,y_)

In [None]:
# R^2 of the all-pairs fit evaluated only on rows 40-44 (positional); can be negative.
# NOTE(review): magic slice with no stated rationale — confirm which MNs these rows are.
reg.score(x_[40:45,0].reshape((-1,1)),y_[40:45,0].reshape((-1,1)))

In [None]:
# First four MN pairs (positional).
lrc_df.iloc[0:4]

In [None]:
from scipy.stats import pearsonr

# Pearson correlation between left and right input counts (statistic and p-value).
pearson_lr = pearsonr(x, y)
pearson_lr

In [None]:
# Fitted slope: change in right-side inputs per additional left-side input.
reg.coef_

In [None]:
# Predicted right-side counts at the two endpoints of the plotted fit line (0 and max left count).
fit