In [1]:
import pandas as pd
import pickle
import navis
import networkx as nx
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from core.Tree import Tree
from core.Branch import BranchSeq
from fafbseg import flywire
from core.Tree import Tree
import os
from core.GenericBranch import GenericBranchSeq



In [2]:
# Root of the pre-computed FlyWire data. Override via the FLYWIRE_DATA_DIR
# environment variable instead of editing an absolute path in the notebook.
FLYWIRE_DATA_DIR = os.environ.get('FLYWIRE_DATA_DIR', '/drive_sdc/ssarup/flywire_data')
MB_BRANCHES_PATH = os.path.join(FLYWIRE_DATA_DIR, 'mushroom_body', 'mb_branches.pkl')
EXAMPLE_CELL_ID = 720575940617650722  # cell whose branch data is inspected below

# NOTE(review): this read_csv call emits a pandas warning in the saved output
# (possibly a mixed-dtype DtypeWarning) -- consider passing dtype= explicitly.
cell_df = pd.read_csv('neuron_annotation.tsv', sep='\t', index_col=0)

# Root IDs of every cell annotated with cell_class == 'Kenyon_Cell'.
mb_ids = cell_df.loc[cell_df['cell_class'] == 'Kenyon_Cell', 'root_id'].values

# pickle.load can execute arbitrary code: only load files produced by this project.
with open(MB_BRANCHES_PATH, 'rb') as fh:
    mb_branches = pickle.load(fh)

# Peek at the per-node synapse annotations of one example branch.
mb_branches[EXAMPLE_CELL_ID][1].node_dict

  cell_df = pd.read_csv('neuron_annotation.tsv', sep='\t', index_col=0)


{1102: {'inc_synapse_pos': array([[644.84 , 138.744, 184.76 ]]),
  'pre_cell_id': [720575940621123113],
  'top_nts': ['acetylcholine'],
  'top_nt_conf': [0.9175052320949768],
  'flow': ['intrinsic'],
  'cell_type': ['Unknown'],
  'hemibrain_type': ['DL2d_adPN'],
  'cell_sub_class': ['uniglomerular'],
  'cell_class': ['ALPN'],
  'super_class': ['central'],
  'cleft_score': [142]},
 1121: {'inc_synapse_pos': array([[645.24 , 141.088, 184.36 ]]),
  'pre_cell_id': [720575940621123113],
  'top_nts': ['acetylcholine'],
  'top_nt_conf': [0.9175052320949768],
  'flow': ['intrinsic'],
  'cell_type': ['Unknown'],
  'hemibrain_type': ['DL2d_adPN'],
  'cell_sub_class': ['uniglomerular'],
  'cell_class': ['ALPN'],
  'super_class': ['central'],
  'cleft_score': [149]},
 1131: {},
 1038: {'inc_synapse_pos': array([[643.456, 141.032, 180.24 ]]),
  'pre_cell_id': [720575940621123113],
  'top_nts': ['acetylcholine'],
  'top_nt_conf': [0.9175052320949768],
  'flow': ['intrinsic'],
  'cell_type': ['Unknow

In [51]:
from core.flywire_utils import *

def _incoming_sequences(branch):
    """Flatten a branch's incoming synapses into two parallel, path-ordered lists.

    Walks ``branch.path`` in order. For each node it takes the first
    ``branch.inc_syn_per_node[node]`` entries of that node's ``'pre_cell_id'``
    and ``'inc_synapse_pos'`` lists in ``branch.node_dict``.

    Returns:
        (cell_ids, positions): the presynaptic cell ID and the synapse position
        of every incoming synapse, in the order met along the path.
    """
    cell_ids = []
    positions = []
    for node in branch.path:
        for k in range(branch.inc_syn_per_node[node]):
            # node_dict is only touched for nodes that actually carry synapses.
            positions.append(branch.node_dict[node]['inc_synapse_pos'][k])
            cell_ids.append(branch.node_dict[node]['pre_cell_id'][k])
    return cell_ids, positions


def _repeat_runs(cell_ids, positions, min_repeat=3):
    """Group positions of synapses that belong to runs of identical consecutive IDs.

    A run is a maximal stretch of equal consecutive entries in ``cell_ids``.
    Every run of length >= ``min_repeat`` contributes *all* of its positions.
    If one ID forms several separate runs, their positions are concatenated
    in path order rather than the later run overwriting the earlier one.

    Returns:
        dict mapping cell ID -> list of positions (empty dict if nothing repeats).
    """
    repeat_pos = {}
    n = len(cell_ids)
    start = 0
    while start < n:
        end = start + 1
        while end < n and cell_ids[end] == cell_ids[start]:
            end += 1
        if end - start >= min_repeat:
            repeat_pos.setdefault(cell_ids[start], []).extend(positions[start:end])
        start = end
    return repeat_pos


def _match_presynapses(graph, post, positions):
    """Match each position to the nearest presynapse in ``graph`` that targets ``post``.

    Candidate presynapses are the entries of a node's ``'out_synapse_pos'``
    whose aligned ``'post_cell_id'`` equals ``post``. Nodes without a
    ``'post_cell_id'`` attribute are skipped.

    Returns:
        One ``((tree_node, synapse_index), euclidean_distance)`` per position,
        in input order. Returns an empty list when ``graph`` has no presynapse
        onto ``post``.
    """
    candidates = []  # (position, tree node, index into that node's synapse lists)
    for node in graph.nodes:
        attrs = graph.nodes[node]
        post_ids = attrs.get('post_cell_id', [])
        if post not in post_ids:
            continue
        out_pos = attrs['out_synapse_pos']
        for idx in np.where(np.asarray(post_ids) == post)[0]:
            candidates.append((out_pos[idx], node, int(idx)))

    if not candidates:
        return []

    cand_pos = np.asarray([c[0] for c in candidates], dtype=float)
    matches = []
    for pos in positions:
        dists = np.linalg.norm(cand_pos - np.asarray(pos, dtype=float), axis=1)
        best = int(np.argmin(dists))  # first minimum on ties
        _, node, idx = candidates[best]
        matches.append(((node, idx), float(dists[best])))
    return matches


def find_presyn_from_branch(branch, min_repeat=3):
    """Find presynaptic cells that contact ``branch`` several times in a row.

    The incoming synapses of ``branch`` are read in path order. Any presynaptic
    cell ID that forms a run of at least ``min_repeat`` consecutive synapses is
    a "repeat" ID. For every repeat ID, its tree is loaded with
    ``load_tree_dict``. Each of its repeat-synapse positions is then matched to
    the closest presynapse on that tree whose ``post_cell_id`` is
    ``branch.cell_id``.

    Args:
        branch: branch object exposing ``cell_id``, ``path``,
            ``inc_syn_per_node``, ``node_dict`` and ``reorder_node_inc_synapses()``.
        min_repeat: minimum run length that counts as a repeat (default 3).

    Returns:
        ``(repeat_pos, repeat_nodes, trees)``, or ``(None, None, None)`` when
        no ID repeats.

        - ``repeat_pos``: {cell_id: array of repeat-synapse positions, one row
          per synapse in path order} (presumably shape (k, 3) -- confirm).
        - ``repeat_nodes``: {cell_id: [((tree_node, synapse_index), distance), ...]},
          aligned with the rows of ``repeat_pos[cell_id]``. The list is empty
          if the tree has no presynapse onto the branch's cell.
        - ``trees``: the mapping returned by ``load_tree_dict``.
    """
    post = branch.cell_id

    # NOTE(review): compares the total incoming-synapse count to the number of
    # path nodes before reordering; the intent is not visible here -- confirm.
    if sum(branch.inc_syn_per_node.values()) != len(branch.path):
        branch.reorder_node_inc_synapses()

    inc_cellid_sequence, inc_synpos_sequence = _incoming_sequences(branch)
    repeat_pos = _repeat_runs(inc_cellid_sequence, inc_synpos_sequence, min_repeat)
    repeat_ids = set(repeat_pos)

    if not repeat_ids:
        print('No repeating IDs found in', inc_cellid_sequence)
        return None, None, None

    print('IDs', repeat_ids, 'repeat in', inc_cellid_sequence)

    trees = load_tree_dict(repeat_ids)
    repeat_nodes = {}
    for cid in repeat_ids:
        repeat_pos[cid] = np.asarray(repeat_pos[cid])
        repeat_nodes[cid] = _match_presynapses(trees[cid].graph, post, repeat_pos[cid])
        if not repeat_nodes[cid]:
            print('No presynapses from', cid, 'onto', post, 'found in its tree')

    return repeat_pos, repeat_nodes, trees



In [52]:
# Run the repeat-synapse search on one example branch: entry 18 of cell 720575940617650722.
example_branch = mb_branches[720575940617650722][18]
repeat_pos, repeat_nodes, trees = find_presyn_from_branch(example_branch)

IDs {720575940615366055} repeat in [720575940613583001, 720575940615366055, 720575940615366055, 720575940615366055]


100%|██████████| 1/1 [00:00<00:00, 12.22it/s]


In [53]:
# Display the three results of the repeat-synapse search together.
(repeat_pos, repeat_nodes, trees)

({720575940615366055: array([[651.656, 134.664, 182.12 ],
         [654.02 , 135.192, 182.72 ],
         [653.608, 136.328, 181.92 ]])},
 {720575940615366055: [((8170, 14), 2.433765806317527),
   ((8170, 14), 0.11327841806800616),
   ((8165, 0), 0.13206059215375898)]},
 {720575940615366055: <core.Tree.Tree at 0x7fc4c1a43a10>})

In [54]:
# Unweighted, direction-agnostic hop count between tree nodes 8170 and 8165,
# the two nodes that appear in repeat_nodes for presynaptic cell 720575940615366055.
presyn_graph = trees[720575940615366055].graph.to_undirected()
nx.shortest_path_length(presyn_graph, source=8170, target=8165)

1