In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import torch
from visualize._utils import draw_connectome
from utils import ROOT_DIR, NEURONS_302
from torch_geometric.data import Data
from tests.leandro.plots import *
from data._main import get_dataset
from omegaconf import OmegaConf

In [4]:
# Read the processed connectome tensors from disk.
# NOTE(review): torch.load unpickles the file — only use it on trusted data.
graph_tensors = torch.load(
    os.path.join(ROOT_DIR, "data", "processed", "connectome", "graph_tensors.pt")
)

# Wrap the loaded tensors in a torch_geometric Data object.
graph = Data(**graph_tensors)

print(f"Graph attributes: {graph.keys}")
print(
    f"Num. nodes: {graph.num_nodes}, "
    f"Num. edges: {graph.num_edges}, "
    f"Num. node features: {graph.num_node_features}"
)

Graph attributes: ['idx_to_neuron', 'node_type', 'n_id', 'edge_attr', 'edge_index', 'y', 'num_classes', 'pos', 'x']
Num. nodes: 302, Num. edges: 4396, Num. node features: 1024


In [5]:
# Lookups in both directions between node indices and neuron names.
neuron_to_idx = {name: idx for idx, name in graph.idx_to_neuron.items()}
idx_to_neuron = dict(graph.idx_to_neuron.items())

In [6]:
def get_connected_neurons(graph, neuron_name, index=False):
    """Return the outgoing neighbours of ``neuron_name`` in ``graph``.

    Only edges whose source (row 0 of ``graph.edge_index``) is the query
    neuron are considered, so the result is the set of targets of its
    outgoing edges, in edge order.

    Parameters
    ----------
    graph : object
        Must expose ``idx_to_neuron`` (mapping node index -> neuron name) and
        ``edge_index`` (2 x E integer tensor; row 0 = source, row 1 = target).
    neuron_name : str
        Name of the query neuron; must be a value of ``graph.idx_to_neuron``.
    index : bool, default False
        If True, return node indices; otherwise return neuron names.

    Returns
    -------
    numpy.ndarray
        Node indices (``index=True``) or neuron names of the connected neurons.

    Raises
    ------
    KeyError
        If ``neuron_name`` is not present in ``graph.idx_to_neuron``.
    """
    import numpy as np  # explicit import instead of relying on a star import

    # Use the graph's own mapping, not a notebook-global dict.
    idx_to_neuron = graph.idx_to_neuron
    neuron_to_idx = {name: idx for idx, name in idx_to_neuron.items()}

    node_index = neuron_to_idx[neuron_name]
    # Positions (in edge_index) of every edge whose source is the query neuron.
    connected_edge_indices = (graph.edge_index[0] == node_index).nonzero()
    # Target node of each of those edges.
    connected_nodes = graph.edge_index[1, connected_edge_indices].flatten()
    # .cpu() makes the NumPy conversion work regardless of the tensor's device.
    connected = connected_nodes.detach().cpu().numpy()

    if index:
        return connected
    return np.array([idx_to_neuron[int(idx)] for idx in connected])

# edge_index: row 0 = source node, row 1 = target node.
# edge_attr: column 0 = gap junction connections, column 1 = chemical synapse connections.
print(graph.edge_index.shape, graph.edge_attr.shape, sep="\n")

torch.Size([2, 4396])
torch.Size([4396, 2])


In [7]:
import base64
from IPython.display import Image, display
import requests
from tests.leandro.hierarchical_clustering_analysis import load_reference

class MiniGraph:
    """Small helper that builds a Mermaid flowchart of neuron connections.

    Each node is coloured by the class label that ``ref_dict`` (loaded via
    ``load_reference``) assigns to its neuron name. The diagram is rendered
    through the mermaid.ink image service.
    """

    # (Mermaid class name, fill colour) pairs emitted as classDef lines.
    _CLASS_COLORS = (
        ('I', '#FF1F5B'),
        ('M', '#00CD6C'),
        ('S', '#009ADE'),
        ('P', '#AF58BA'),
        ('MI', '#FFC61E'),
        ('SM', '#F28522'),
        ('SI', '#A0B1BA'),
        ('SMI', '#A6761D'),
        ('U', '#E9002D'),
    )

    def __init__(self, direction='TD', group_by='four'):
        # Mermaid header; ``direction`` is the flowchart layout direction.
        self.graph = f"""graph {direction};\n"""
        # Unique neuron names seen so far, in first-seen order.
        self.nodes = []
        # Neuron name -> class label used for colouring
        # (presumably one of the _CLASS_COLORS names — verify against load_reference).
        self.ref_dict = load_reference(group_by=group_by)

    def add_nodes(self, neuron, connected_neurons):
        """Add edges ``neuron --> n`` for every ``n`` in ``connected_neurons``.

        Raises AssertionError if ``neuron`` is not a str or
        ``connected_neurons`` is not a list.
        """
        assert isinstance(neuron, str), 'Neuron must be a string'
        assert isinstance(connected_neurons, list), 'Connected_neurons must be a list'
        if connected_neurons:
            # Mermaid's '&' links one source to several targets on one line.
            self.graph += '    {} --> {};\n'.format(neuron, ' & '.join(connected_neurons))
        else:
            # 'A --> ;' is invalid Mermaid, so only declare the isolated node.
            self.graph += '    {};\n'.format(neuron)
        # dict.fromkeys removes duplicates while keeping a deterministic order.
        self.nodes = list(dict.fromkeys(self.nodes + [neuron] + connected_neurons))

    def _color_lines(self):
        """Return the classDef/class lines that colour every known node."""
        parts = [
            f'    classDef {name} fill: {color}, stroke:#000000, stroke-width:1px;\n'
            for name, color in self._CLASS_COLORS
        ]
        # Raises KeyError if a node has no entry in ref_dict.
        parts.extend(f'    class {neuron_name} {self.ref_dict[neuron_name]};\n'
                     for neuron_name in self.nodes)
        return ''.join(parts)

    def add_color(self):
        """Append the colouring lines to ``self.graph`` (mutates the diagram)."""
        self.graph += self._color_lines()

    def display(self, save=False, filename='graph.png'):
        """Render the diagram via mermaid.ink and optionally save the image.

        Does not modify ``self.graph``, so it can be called repeatedly without
        duplicating the colouring lines.
        """
        colored_graph = self.graph + self._color_lines()
        base64_string = base64.b64encode(colored_graph.encode("ascii")).decode("ascii")
        image_url = "https://mermaid.ink/img/" + base64_string
        display(Image(url=image_url))

        if save:
            # Download the rendered image from the same URL.
            response = requests.get(image_url, timeout=30)
            # Fail loudly instead of writing an HTTP error page to disk.
            response.raise_for_status()
            with open(filename, 'wb') as f:
                f.write(response.content)
            print(f"Graph saved as {filename}")

    def print(self):
        """Print the raw Mermaid source built so far."""
        print(self.graph)

In [8]:
# For each neuron, gather the names of its outgoing neighbours.
# (Despite its name, num_conexoes holds neighbour arrays, not counts.)
num_conexoes = [
    get_connected_neurons(graph, neuron_name, index=False)
    for neuron_name in NEURONS_302
]

# List every neuron that has exactly N outgoing connections.
N = 1
print('Neurons with {} connections: {}'.format(
    N,
    [name for name, partners in zip(NEURONS_302, num_conexoes) if len(partners) == N],
))

Neurons with 1 connections: ['AS7', 'AS8', 'AS9', 'DD5', 'M3R', 'SABVL']


In [12]:
# Outgoing partners of I4 as neuron names (index defaults to False).
get_connected_neurons(graph, 'I4')

array(['I2R', 'I5', 'M2L', 'M2R', 'M3L', 'M3R', 'M4', 'MI', 'NSML',
       'NSMR'], dtype='<U4')

In [10]:
# Draw SABVL -> RIGR, then RIGR -> each of its listed partners.
# NOTE(review): the RIGR list looks like pasted output of
# get_connected_neurons(graph, 'RIGR'); consider computing it instead.
minigraph = MiniGraph()
minigraph.add_nodes('SABVL', ['RIGR'])
minigraph.add_nodes(
    'RIGR',
    [
        'ADEL', 'AFDR', 'AIBL', 'AIZR', 'ALNR', 'AQR', 'AVDL', 'AVER', 'AVKL',
        'AVKR', 'BAGL', 'DVA', 'DVC', 'OLLR', 'OLQDR', 'OLQVR', 'PVT', 'RIAR',
        'RIBR', 'RIR', 'RMFL', 'RMHL', 'URXR', 'URYDR', 'URYVR', 'VD1',
    ],
)
minigraph.display()

In [13]:
# Load the dataset configuration relative to the project root rather than from
# an absolute, machine-specific path.
# NOTE(review): assumes ROOT_DIR is the repository root that contains conf/ — confirm.
config = OmegaConf.load(os.path.join(ROOT_DIR, "conf", "dataset.yaml"))
dataset = get_dataset(config)

# Print each worm whose slot_to_named_neuron mapping contains a target neuron.
for wormID, data in dataset.items():
    # A set of the mapped neuron names gives O(1) membership checks.
    neuron_names = set(data['slot_to_named_neuron'].values())
    for target in ['M3L', 'NSML']:
        if target in neuron_names:
            print('WormID: {}, Target: {}'.format(wormID, target))

Chosen dataset(s): ['Flavell2023']
Num. worms: 50

WormID: worm0, Target: M3L
WormID: worm0, Target: NSML
WormID: worm1, Target: M3L
WormID: worm2, Target: M3L
WormID: worm2, Target: NSML
WormID: worm3, Target: M3L
WormID: worm3, Target: NSML
WormID: worm4, Target: M3L
WormID: worm4, Target: NSML
WormID: worm5, Target: M3L
WormID: worm5, Target: NSML
WormID: worm6, Target: M3L
WormID: worm6, Target: NSML
WormID: worm7, Target: M3L
WormID: worm7, Target: NSML
WormID: worm8, Target: M3L
WormID: worm8, Target: NSML
WormID: worm9, Target: M3L
WormID: worm9, Target: NSML
WormID: worm10, Target: M3L
WormID: worm10, Target: NSML
WormID: worm11, Target: M3L
WormID: worm11, Target: NSML
WormID: worm12, Target: M3L
WormID: worm12, Target: NSML
WormID: worm13, Target: M3L
WormID: worm13, Target: NSML
WormID: worm14, Target: M3L
WormID: worm14, Target: NSML
WormID: worm15, Target: M3L
WormID: worm15, Target: NSML
WormID: worm16, Target: M3L
WormID: worm16, Target: NSML
WormID: worm17, Target: M3L
