In [1]:
import networkx as nx
import pandas as pd
import torch_geometric
from torch_geometric.utils import from_networkx
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2

In [2]:
# Load the three CSV inputs used by this notebook, in one pass:
#   neuron_features - compact per-neuron table
#   edge_features   - neuron-to-neuron connection table
#   all_neurons     - per-neuron table with the full column set
neuron_features, edge_features, all_neurons = (
    pd.read_csv(csv_name)
    for csv_name in (
        "filtered_neurons.csv",
        "filtered_cells.csv",
        "filtered_neurons_all_features.csv",
    )
)

In [3]:
# Preview the first five rows of the compact per-neuron table.
neuron_features.head(n=5)

Unnamed: 0.1,Unnamed: 0,bodyId,type,pre,post,size
0,69718,2251225397,,13,3,6796384
1,53681,1903362067,,26,20,12958457
2,80812,2465106393,,4,2,2684477
3,80160,2460357918,,18,8,18561622
4,80704,2464472373,,8,16,4746459


In [4]:
# Preview the first five rows of the connection table (one row per pre/post/ROI).
edge_features.head(n=5)

Unnamed: 0,bodyId_pre,bodyId_post,roi,weight
0,357250124,819725372,ICL(R),11
1,357250124,858812398,ICL(R),31
2,357250124,858812398,SCL(R),15
3,391289810,298935111,SLP(R),11
4,391289810,330268940,SLP(R),17


In [5]:
# Preview the first five rows of the full-feature neuron table.
all_neurons.head(n=5)

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,bodyId,instance,type,pre,post,downstream,upstream,mito,...,status,cropped,statusLabel,cellBodyFiber,somaRadius,somaLocation,roiInfo,notes,inputRois,outputRois
0,0,69718,2251225397,,,13,3,76,3,2,...,Traced,True,Leaves,,,,"{'OL(R)': {'pre': 13, 'post': 3, 'downstream':...",,"['LO(R)', 'OL(R)']","['LO(R)', 'OL(R)']"
1,1,53681,1903362067,,,26,20,158,20,5,...,Traced,True,Leaves,,,,"{'OL(R)': {'pre': 26, 'post': 20, 'downstream'...",,"['LO(R)', 'OL(R)']","['LO(R)', 'OL(R)']"
2,2,80812,2465106393,,,4,2,16,2,2,...,Traced,True,Leaves,,,,"{'OL(R)': {'pre': 4, 'post': 2, 'downstream': ...",,"['LO(R)', 'OL(R)']","['LO(R)', 'OL(R)']"
3,3,80160,2460357918,,,18,8,145,8,4,...,Traced,True,Leaves,,,,"{'OL(R)': {'pre': 18, 'post': 8, 'downstream':...",,"['LO(R)', 'LOP(R)', 'OL(R)']","['LO(R)', 'LOP(R)', 'OL(R)']"
4,4,80704,2464472373,,,8,16,60,16,1,...,Traced,True,Leaves,,,,"{'OL(R)': {'pre': 8, 'post': 16, 'downstream':...",,"['LO(R)', 'OL(R)']","['LO(R)', 'OL(R)']"


In [6]:
def extract_xyz(location):
    """Parse a soma location into an ``(x, y, z)`` tuple of floats.

    Accepted inputs:

    * a string holding three comma-separated numbers wrapped in parentheses
      or square brackets, e.g. ``"(1, 2, 3)"`` or ``"[1, 2, 3]"``; surrounding
      whitespace is ignored. The bracketed form is what pandas writes to CSV
      for a column holding Python lists.
    * an actual ``list``/``tuple`` of three numbers.

    Anything else (``NaN``, ``None``, an empty or unwrapped string, ...) is
    treated as "no soma location" and mapped to ``(0.0, 0.0, 0.0)``.

    Parameters
    ----------
    location : object
        Raw cell value from the ``somaLocation`` column.

    Returns
    -------
    tuple of float
        ``(x, y, z)``, or ``(0.0, 0.0, 0.0)`` when no location is available.

    Raises
    ------
    ValueError
        If a wrapped string does not contain exactly three numeric components.
    """
    if isinstance(location, str):
        text = location.strip()
        # Accept both tuple-style "(...)" and list-style "[...]" encodings.
        if text.startswith(("(", "[")):
            x, y, z = text.strip("()[]").split(",")
            return float(x), float(y), float(z)
    elif isinstance(location, (list, tuple)) and len(location) == 3:
        x, y, z = location
        return float(x), float(y), float(z)
    # Missing / unrecognised value: fall back to the origin as a sentinel.
    return 0.0, 0.0, 0.0

In [7]:
# Build a directed graph from the neuron and connection tables, then convert
# it to a PyTorch Geometric `Data` object.

# --- Soma coordinates -------------------------------------------------------
# Guarded so the cell can be re-run: after the first execution `somaLocation`
# has already been replaced by the x/y/z columns, and an unguarded second run
# would fail with KeyError.
if 'somaLocation' in all_neurons.columns:
    soma_xyz = pd.DataFrame(
        all_neurons['somaLocation'].map(extract_xyz).tolist(),
        index=all_neurons.index,
        columns=['x', 'y', 'z'],
    )
    all_neurons[['x', 'y', 'z']] = soma_xyz
    all_neurons = all_neurons.drop(columns=['somaLocation'])

# --- Nodes ------------------------------------------------------------------
G = nx.DiGraph()

# NOTE(review): from_networkx copies node attributes onto `Data` under their
# own names, so the soma x / y coordinates end up as `data.x` / `data.y`
# (see the printed summary), which PyG conventionally uses for the node
# feature matrix and the targets. The names are kept here so later cells keep
# working; consider renaming them or building `data.x` explicitly.
node_columns = ['bodyId', 'type', 'pre', 'post', 'size', 'x', 'y', 'z']
G.add_nodes_from(
    (
        body_id,
        {
            'neuron_type': neuron_type,
            'pre': pre,
            'post': post,
            'size': size,
            'x': x,
            'y': y,
            'z': z,
        },
    )
    for body_id, neuron_type, pre, post, size, x, y, z
    in all_neurons[node_columns].itertuples(index=False, name=None)
)

# --- Edges ------------------------------------------------------------------
# The connection table can hold several rows for the same (pre, post) pair,
# one per ROI. A DiGraph stores a single edge per pair and `add_edge` on an
# existing edge overwrites its attributes, so adding rows one by one silently
# keeps only the last ROI's weight. Collapse each pair first: the weight is
# the sum over its rows and `roi` is the ROI contributing the largest weight.
# NOTE(review): assumes the per-ROI rows do not overlap, so summing does not
# double count - confirm against how filtered_cells.csv was produced.
pair_columns = ['bodyId_pre', 'bodyId_post']
pair_groups = edge_features.groupby(pair_columns, sort=False)  # keep first-seen order
edges_agg = pair_groups['weight'].sum().reset_index()
edges_agg['roi'] = edge_features.loc[pair_groups['weight'].idxmax(), 'roi'].to_numpy()

G.add_edges_from(
    (pre_id, post_id, {'weight': weight, 'roi': roi})
    for pre_id, post_id, weight, roi
    in edges_agg[pair_columns + ['weight', 'roi']].itertuples(index=False, name=None)
)

print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")

# --- Uniform node attributes ------------------------------------------------
# Nodes that only appear as edge endpoints carry no attributes, while
# from_networkx needs every node to expose the same attribute keys.
# Fill whatever is missing with 0.
all_attrs = set()
for _, node_attrs in G.nodes(data=True):
    all_attrs.update(node_attrs)

for _, node_attrs in G.nodes(data=True):
    for missing_key in all_attrs.difference(node_attrs):
        node_attrs[missing_key] = 0

# --- Conversion -------------------------------------------------------------
data = from_networkx(G)
print(data)

Graph has 11549 nodes and 14424 edges.
Data(x=[11549], edge_index=[2, 14424], y=[11549], neuron_type=[11549], pre=[11549], post=[11549], size=[11549], z=[11549], weight=[14424], roi=[14424])
