In [1]:
import pickle
import pandas as pd


In [2]:
cell_df = pd.read_csv('./data/post_ids.csv')
cell_df.head()

Unnamed: 0,pt_root_id
0,3531662040
1,3691674312
2,3967005302
3,4911602202
4,23635194658


In [3]:
cell_ids = cell_df['pt_root_id'].tolist()

In [4]:
len(cell_ids)

8380

In [5]:
import os

found_keys = set()
for fname in os.listdir('./data/trees'):
    if fname.endswith('.pkl'):
        with open(os.path.join('./data/trees',fname), 'rb') as f:
            trees = pickle.load(f)
        found_keys.update(trees.keys())

  @jit(parallel=True, fastmath=True, nopython=False)
  @jit(parallel=True, fastmath=True, nopython=False)


In [6]:
len(found_keys)

4420

In [7]:
for fname in os.listdir('./data/trees2'):
    if fname.endswith('.pkl'):
        with open(os.path.join('./data/trees2',fname), 'rb') as f:
            trees = pickle.load(f)
        found_keys.update(trees.keys())

In [8]:
len(found_keys)

5760

In [9]:
for fname in os.listdir('./data/trees3'):
    if fname.endswith('.pkl'):
        with open(os.path.join('./data/trees3',fname), 'rb') as f:
            trees = pickle.load(f)
        found_keys.update(trees.keys())

In [10]:
for fname in os.listdir('./data/trees4'):
    if fname.endswith('.pkl'):
        with open(os.path.join('./data/trees4',fname), 'rb') as f:
            trees = pickle.load(f)
        found_keys.update(trees.keys())

In [11]:
len(found_keys)

8350

In [12]:
missing_keys = set(cell_ids) - found_keys
missing_cids = list(missing_keys)

In [13]:
len(missing_keys)

30

In [14]:
def filter_and_connect_graph(original_graph, desired_nodes):
    # Step 1: Initially, identify nodes with multiple parents and nodes to keep
    nodes_with_multiple_children = {node for node in original_graph.nodes() if original_graph.out_degree(node) > 1}
    nodes_to_keep = desired_nodes.union(nodes_with_multiple_children).union({-1})  # Include root node

    # Create a copy of the graph to work on
    G = original_graph.copy()

    # Step 2: For nodes not in the keep list, redirect parents to children and remove the node
    for node in list(G.nodes()):  # List conversion to avoid modification during iteration
        if node not in nodes_to_keep:
            parents = list(G.predecessors(node))
            children = list(G.successors(node))
            for parent in parents:
                for child in children:
                    G.add_edge(parent, child)  # Connect parent directly to child
            G.remove_node(node)  # Remove the node after re-connecting

    assert -1 in G.nodes(), "Root node not found in the graph"
    return G

In [18]:
import networkx as nx
import navis
import cloudvolume as cv
import networkx as nx
import numpy as np

from core.Tree import Tree
from core.Branch import BranchSeq

def process_chunk(cid, vol, syn_df):
    nrns = vol.mesh.get(cid, as_navis=True)
    nrns = navis.simplify_mesh(nrns, F=1/3, parallel=True)
    print('Simplified')
    sks = navis.skeletonize(nrns, parallel=True)
    print('Skeletonized')
    sks = navis.heal_skeleton(sks, parallel=True)
    print('Healed')
    sks = navis.prune_twigs(sks, 6000, parallel=True)
    trees = {}
    branches = {}
    for skp in sks:
        syn_pos = np.array(syn_df[syn_df['post_pt_root_id'] == skp.id][['x', 'y', 'z']].values) * np.array([8, 8, 33])
        pre_cell_ids = np.array(syn_df[syn_df['post_pt_root_id'] == skp.id]['pre_pt_root_id'].values)
        syn_ids = np.array(syn_df[syn_df['post_pt_root_id'] == skp.id].index)
        
        segment_length_dict = {node: navis.segment_length(skp, seg)/1000 for seg in skp.segments for node in seg}
        node_ids,_ = skp.snap(syn_pos)
        RG = skp.get_graph_nx().reverse()

        root_pos = np.array(skp.nodes.iloc[skp.root][['x', 'y', 'z']].values[0])

        RG.add_node(-1, pos=root_pos/1000)
        for r in skp.root:
            RG.add_edge(r, -1)
        G = filter_and_connect_graph(RG, set(node_ids))
        G.graph['cell_id'] = skp.id
        branch_lengths = [segment_length_dict.get(node, 0) for node in list(G.nodes)]
        # set node attributes of G with syn_pos, cell_types, and pre_cell_ids
        nx.set_node_attributes(G, dict(zip(node_ids, syn_pos/1000)), 'pos')
        nx.set_node_attributes(G, dict(zip(node_ids, len(node_ids)*['4P'])), 'cell_type')
        nx.set_node_attributes(G, dict(zip(node_ids, pre_cell_ids)), 'pre_cell_id')
        nx.set_node_attributes(G, dict(zip(list(G.nodes), branch_lengths)), 'branch_length')

        tree = Tree(G.reverse(), root_id=-1)
        trees[skp.id] = tree
        branches[skp.id] = [BranchSeq(path, tree.graph, (cid, j)) for j, path in enumerate(tree.get_paths())]
    
    return trees, branches

In [None]:
navis.patch_cloudvolume()
vol = cv.CloudVolume('precomputed://gs://h01-release/data/20210601/c3', use_https=True, progress=True, parallel=True)

syn_df = pd.read_csv('/home/saarthak/microns/data/syn_df.csv')

In [16]:
extra_trees = {}
extra_branches = {}

In [20]:
chunk_size = 5
start = 4
for s in range(start, int(np.ceil(len(missing_cids)/chunk_size))):
    chunk = missing_cids[s*chunk_size:(s+1)*chunk_size]
    print('Starting chunk', s)
    trees, branches = process_chunk(chunk, vol, syn_df)
    extra_trees.update(trees)
    extra_branches.update(branches)
    print('Finished chunk', s)

Starting chunk 4


Simplifying:   0%|          | 0/5 [00:00<?, ?it/s]

Simplified


Skeletonizing:   0%|          | 0/5 [00:00<?, ?it/s]

Skeletonized


Healing:   0%|          | 0/5 [00:00<?, ?it/s]

Healed


Pruning:   0%|          | 0/5 [00:00<?, ?it/s]

Finished chunk 4
Starting chunk 5


Simplifying:   0%|          | 0/5 [00:00<?, ?it/s]

Simplified


Skeletonizing:   0%|          | 0/5 [00:00<?, ?it/s]

Skeletonized


Healing:   0%|          | 0/5 [00:00<?, ?it/s]

Healed


Pruning:   0%|          | 0/5 [00:00<?, ?it/s]

Finished chunk 5


In [21]:
bad_chunks = [0, 3]

In [26]:
indiv_ids = [x for i in bad_chunks for x in missing_cids[i*chunk_size:(i+1)*chunk_size] ]
start = 9
for s in range(start, len(indiv_ids)):
    iid = indiv_ids[s]
    print('Starting at', s, 'with id', iid)
    trees, branches = process_chunk([iid], vol, syn_df)
    extra_trees.update(trees)
    extra_branches.update(branches)
    print('Finished id', iid)

Starting at 9 with id 48656505068


Simplifying:   0%|          | 0/1 [00:00<?, ?it/s]

Simplified


Skeletonizing:   0%|          | 0/1 [00:00<?, ?it/s]

Skeletonized


Healing:   0%|          | 0/1 [00:00<?, ?it/s]

Healed


Pruning:   0%|          | 0/1 [00:00<?, ?it/s]

Finished id 48656505068


In [None]:
bad_ids = [5131361032, 6644463791, 38613441248, 3425979752]

In [27]:
# save the extra trees and branches
with open('./data/trees3/trees_200.pkl', 'wb') as f:
    pickle.dump(extra_trees, f)
with open('./data/trees3/branches_200.pkl', 'wb') as f:
    pickle.dump(extra_branches, f)

In [None]:
# bad_ids = []
# # remove bad_ids from cell_df
# cell_df = cell_df[~cell_df['pt_root_id'].isin(bad_ids)]
# cell_df.to_csv('./data/post_ids.csv', index=False)

In [28]:
# collate into one branches and trees dict
all_trees = {}
all_branches = {}

for fname in os.listdir('./data/trees'):
    if fname.endswith('.pkl'):
        with open(os.path.join('./data/trees',fname), 'rb') as f:
            trees = pickle.load(f)
        all_trees.update(trees)

for fname in os.listdir('./data/trees2'):
    if fname.endswith('.pkl'):
        with open(os.path.join('./data/trees2',fname), 'rb') as f:
            trees = pickle.load(f)
        all_trees.update(trees)

for fname in os.listdir('./data/trees3'):
    if fname.endswith('.pkl'):
        with open(os.path.join('./data/trees3',fname), 'rb') as f:
            trees = pickle.load(f)
        all_trees.update(trees)

for fname in os.listdir('./data/trees4'):
    if fname.endswith('.pkl'):
        with open(os.path.join('./data/trees4',fname), 'rb') as f:
            trees = pickle.load(f)
        all_trees.update(trees)

for fname in os.listdir('./data/branches'):
    if fname.endswith('.pkl'):
        with open(os.path.join('./data/branches',fname), 'rb') as f:
            branches = pickle.load(f)
        all_branches.update(branches)

for fname in os.listdir('./data/branches2'):
    if fname.endswith('.pkl'):
        with open(os.path.join('./data/branches2',fname), 'rb') as f:
            branches = pickle.load(f)
        all_branches.update(branches)

for fname in os.listdir('./data/branches3'):
    if fname.endswith('.pkl'):
        with open(os.path.join('./data/branches3',fname), 'rb') as f:
            branches = pickle.load(f)
        all_branches.update(branches)

for fname in os.listdir('./data/branches4'):
    if fname.endswith('.pkl'):
        with open(os.path.join('./data/branches4',fname), 'rb') as f:
            branches = pickle.load(f)
        all_branches.update(branches)

with open('./data/all_trees.pkl', 'wb') as f:
    pickle.dump(all_trees, f)

with open('./data/all_branches.pkl', 'wb') as f:
    pickle.dump(all_branches, f)

In [29]:
from collections import defaultdict

total_branches = 0
for cid in all_branches:
    total_branches += len(all_branches[cid])
print('Total branch count:', total_branches)
# create a filtered version of all_branches

filt_branch_ct = 0
filt_branches = defaultdict(list)
for cid in all_branches:
    valid_branches = []
    for branch in all_branches[cid]:
        if len(branch.cell_id_sequence['collapsed']) >= 3:
            valid_branches.append(branch)
    if len(valid_branches) > 0:
        filt_branches[cid] = valid_branches
        filt_branch_ct += len(valid_branches)

print('Filtered branch count:', filt_branch_ct)
with open('./data/filt_branches.pkl', 'wb') as f:
    pickle.dump(filt_branches, f)

Total branch count: 1197042
Filtered branch count: 23212


In [30]:
all_inputs = set()
for cid in filt_branches:
    for branch in filt_branches[cid]:
        all_inputs.update(branch.cell_id_sequence['collapsed'])

print('Total input count:', len(all_inputs))

# save as a pandas dataframe with list in column 'pt_root_id'
all_inputs = list(all_inputs)
all_inputs = pd.DataFrame({'pt_root_id': all_inputs})
all_inputs.to_csv('./data/pre_ids.csv', index=False)

Total input count: 100826
