In [None]:
from cedne import utils
import matplotlib.pyplot as plt
from collections import Counter
import numpy as np
import copy
import os
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

In [None]:
# Make sure the output directory exists; exist_ok avoids the check-then-create race
# and also creates any missing parent directories.
os.makedirs(utils.OUTPUT_DIR, exist_ok=True)

In [None]:
def simpleaxis(axes, every=False, outward=False):
    """Strip decorative elements from one or more Matplotlib axes.

    Hides the top and right spines, keeps ticks on the bottom/left only and
    clears the title of every axes passed in.

    Parameters
    ----------
    axes : Axes, or a (possibly nested) list / tuple / ndarray of Axes
        A single axes or any container of them, e.g. the 2-D array returned
        by ``plt.subplots(nrows, ncols)``.
    every : bool, default False
        If True, the bottom and left spines are hidden as well.
    outward : bool, default False
        If True, the bottom and left spines are offset outward by 10.
    """
    def _iter_axes(obj):
        # Walk containers recursively so nested lists and 2-D grids are handled.
        if isinstance(obj, np.ndarray):
            obj = obj.ravel()
        if isinstance(obj, (list, tuple, np.ndarray)):
            for item in obj:
                yield from _iter_axes(item)
        else:
            yield obj

    for ax in _iter_axes(axes):
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        if outward:
            ax.spines['bottom'].set_position(('outward', 10))
            ax.spines['left'].set_position(('outward', 10))
        if every:
            ax.spines['bottom'].set_visible(False)
            ax.spines['left'].set_visible(False)
        ax.get_xaxis().tick_bottom()
        ax.get_yaxis().tick_left()
        ax.set_title('')

In [None]:
# Neuron classes tracked throughout the notebook, and the face colour used for
# each class (same order) when drawing pie charts.
ntype = ['sensory', 'interneuron', 'motorneuron']
facecolors = ['#FF6F61', '#FFD700', '#4682B4']

# Every ordered (first, second) combination of neuron classes.
ntype_pairs = [(first, second) for first in ntype for second in ntype]

# One colour from the magma colormap per ordered pair.
colors = plt.cm.magma(np.linspace(0, 1, len(ntype_pairs)))
type_color_dict = dict(zip(ntype_pairs, colors))

In [None]:
# Build the worm model. NOTE(review): chem_only=True presumably restricts the model to
# chemical synapses — confirm against cedne.utils.makeWorm.
w = utils.makeWorm(chem_only=True)
# The network stored under the "Neutral" key; every motif search below runs on it.
nn = w.networks["Neutral"]

In [None]:
# Mapping of triad name (keys such as '300', '003', '030T' are used below) -> motif graph.
triad_motifs = utils.return_triads()

In [None]:
# Gallery of all triad motifs, one panel per motif. Every panel reuses the node
# positions of the '300' triad so that the same node sits in the same place.
nodesize = 800
pos = utils.nx.circular_layout(triad_motifs['300'])
fig, ax = plt.subplots(ncols=len(triad_motifs), figsize=(len(triad_motifs)*2, 2), layout='constrained')
for j, mot in enumerate(triad_motifs):
    # The artists returned by the draw calls are not needed, so they are not bound.
    utils.nx.draw_networkx_nodes(triad_motifs[mot], node_size=nodesize, pos=pos, ax=ax[j], node_color='lightgray')
    utils.nx.draw_networkx_edges(triad_motifs[mot], node_size=nodesize, pos=pos, ax=ax[j], arrowstyle='->', width=6, arrowsize=40)
    utils.nx.draw_networkx_labels(triad_motifs[mot], pos=pos, ax=ax[j], font_size='xx-large')
    ax[j].set_xlim((-0.75, 1.25))
    ax[j].set_ylim((-1.15, 1.15))
    ax[j].set_frame_on(False)
    for spine in ax[j].spines.values():
        spine.set_visible(False)
    ax[j].set_title(mot, fontsize='xx-large')
fig.set_facecolor('none')
plt.show()

In [None]:
# For every triad motif: count its occurrences in the network and tally which
# neuron classes fill each motif node and each motif edge.
num_graphs = {}     # motif name -> number of matches
conn_types = {}     # motif name -> {(class, class): number of matched edges}
ntype_motif = {}    # motif name -> {motif node: {class: count}}
for mot in triad_motifs:
    triad = triad_motifs[mot]
    conn_types[mot] = dict.fromkeys(ntype_pairs, 0)
    all_matches = nn.search_motifs(triad)
    num_graphs[mot] = len(all_matches)
    ntype_motif[mot] = {m: dict.fromkeys(ntype, 0) for m in triad.nodes}
    for sub in all_matches:
        # each match maps a motif edge to the network edge that realises it
        for motif_edge, network_edge in sub.items():
            # ignore edges touching a neuron whose class is not tracked
            if not all(n.type in ntype for n in network_edge):
                continue
            for m, n in zip(motif_edge, network_edge):
                ntype_motif[mot][m][n.type] += 1
            conn = (network_edge[0].type, network_edge[1].type)
            conn_types[mot][conn] += 1

In [None]:
def hierarchical_alignment(conn_types_mot):
    """Score how strongly connections follow the sensory -> inter -> motor order.

    Parameters
    ----------
    conn_types_mot : mapping
        Counts keyed by ``(source_class, target_class)`` tuples. The six
        cross-class keys built from 'sensory', 'interneuron' and
        'motorneuron' must be present.

    Returns
    -------
    float
        ``(feedforward - feedback) / (feedforward + feedback)``; NaN when
        there are no feedforward or feedback connections at all.
    """
    feedforward = (
        conn_types_mot[('sensory', 'interneuron')]
        + conn_types_mot[('sensory', 'motorneuron')]
        + conn_types_mot[('interneuron', 'motorneuron')]
    )
    feedback = (
        conn_types_mot[('interneuron', 'sensory')]
        + conn_types_mot[('motorneuron', 'interneuron')]
        + conn_types_mot[('motorneuron', 'sensory')]
    )
    # Same-class (lateral) connections are deliberately excluded from the score.
    lateral = 0
    total = feedforward + feedback + lateral
    if total == 0:
        # Nothing to align: report "undefined" instead of raising ZeroDivisionError.
        return float('nan')
    return (feedforward - feedback) / total

In [None]:
from cedne.utils import OUTPUT_DIR


# Alignment score per motif, in triad_motifs order; the '003' triad is excluded and reported as NaN.
ha = [np.nan if mot in ['003'] else hierarchical_alignment(conn_types[mot]) for mot in triad_motifs]

def fig2array(fig):
    """Rasterise a Matplotlib figure into an ``(height, width, 3)`` uint8 RGB array.

    The figure is rendered with an Agg canvas. ``buffer_rgba`` is used instead
    of ``tostring_rgb`` (deprecated and removed in recent Matplotlib), and the
    array shape comes from the rendered buffer itself rather than being
    recomputed from ``size_inches * dpi``.
    """
    canvas = FigureCanvas(fig)
    canvas.draw()
    rgba = np.asarray(canvas.buffer_rgba())
    # Drop the alpha channel; copy so the result does not alias the canvas buffer.
    return rgba[..., :3].copy()

# Draw each triad motif as a small stand-alone figure. The figures are collected
# in `figs` and rasterised further down to sit under the x-axis of the summary plot.
nodesize = 300
edgewidth = 2
arrowsize = 10
figs = []
pos = utils.nx.circular_layout(triad_motifs['300'])
for mot in triad_motifs:
    f, ax = plt.subplots(figsize=(1.5, 1.5), dpi=300)
    # The artists returned by the draw calls are not needed, so they are not bound.
    utils.nx.draw_networkx_nodes(triad_motifs[mot], node_size=nodesize, pos=pos, ax=ax, node_color='lightgray')
    utils.nx.draw_networkx_edges(triad_motifs[mot], node_size=nodesize, pos=pos, ax=ax, arrowstyle='->', width=edgewidth, arrowsize=arrowsize)
    utils.nx.draw_networkx_labels(triad_motifs[mot], pos=pos, ax=ax, font_size='xx-large')
    ax.set_xlim((-0.75, 1.25))
    ax.set_ylim((-1.15, 1.15))
    ax.set_frame_on(False)
    for spine in ax.spines.values():
        spine.set_visible(False)
    f.set_facecolor('none')
    figs.append(f)
    
# Hierarchical alignment of every triad motif, with a thumbnail of the motif under each tick.
f, ax = plt.subplots(figsize=(20,3), layout='constrained')
n_motifs = len(triad_motifs)
ax.scatter(np.arange(len(ha)), ha, color='k', s=100)
simpleaxis(ax)
ax.set_ylabel("Hierarchical\nAlignment", fontsize='x-large')

# Motifs are numbered 1..N on the x-axis; the thumbnails below identify them.
# (rotation=45 is kept because the original first set_xticks call left the labels rotated.)
ax.set_xticks(np.arange(n_motifs), np.arange(n_motifs)+1, fontsize='x-large', rotation=45)
ax.tick_params(axis='y', labelsize='x-large')
ax.set_xlim((-0.5, n_motifs-0.5))
ax.set_yticks([0,0.4,0.8])

# Dashed reference line at the highest alignment score (NaN entries ignored).
ax.axhline(y=np.nanmax(ha), linestyle='--', color='gray')

# Rasterise each thumbnail figure, pin it below its tick (axes-fraction coordinates),
# then close the thumbnail so the small figures do not pile up in memory or in the cell output.
for i, thumb in enumerate(figs):
    offset_image = OffsetImage(fig2array(thumb), zoom=0.2)
    annotation_box = AnnotationBbox(offset_image, ((i+0.5)/n_motifs, -0.4), frameon=False, xycoords='axes fraction')
    ax.add_artist(annotation_box)
    plt.close(thumb)

# os.path.join supplies the separator that plain string concatenation with OUTPUT_DIR omitted.
f.savefig(os.path.join(utils.OUTPUT_DIR, "Hierarchical_alignment_triads.svg"), transparent=True)
plt.show()

In [None]:
# One figure per triad motif: the motif's edges, with a pie at each node showing
# the fraction of sensory / inter / motor neurons that occupy that node.
piesize = 0.3
edgewidth = 4
arrowsize = 20
pos = utils.nx.circular_layout(triad_motifs['300'])
color_dict = {p: facecolors for p in pos.keys()}
alpha_dict = {p: 1 for p in pos.keys()}
for mot in triad_motifs:
    # Per node: class fractions, or an empty list when the node was never matched.
    pie_division = {}
    for m in triad_motifs[mot].nodes:
        counts = ntype_motif[mot][m]
        total = sum(counts.values())
        pie_division[m] = [counts[n]/total for n in ntype] if total else []

    if not all(pie_division.values()):
        # Nothing to draw (and therefore nothing to save) for this motif.
        print(f"No edges found for {mot}")
        continue

    f, ax = plt.subplots(figsize=(2,2), layout='constrained')
    utils.nx.draw_networkx_edges(triad_motifs[mot], pos=pos, ax=ax, node_size=600, connectionstyle='arc3', arrowstyle='->', width=edgewidth, arrowsize=arrowsize)
    for n, p in pos.items():
        utils.plot_pie(n=n, center=p, ax=ax, color_dict=color_dict, alpha_dict=alpha_dict, pie_division=pie_division[n], piesize=piesize)
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_aspect('equal')
    ax.axis('off')
    f.suptitle(mot)
    # Save before show(): once the figure has been shown, a later plt.savefig can write an empty canvas.
    f.savefig(f'{utils.OUTPUT_DIR}/motif_ntype_division-{mot}.svg', transparent=True)
    plt.show()
    plt.close(f)

## Random graph search

In [None]:
# Triad census on randomized networks.
#
# Each of the `nrgraphs` randomizations is generated once, with a fixed seed, and
# ALL motifs are searched in that same graph. Index j of every per-motif list
# therefore refers to the same random network, which is what the downstream
# per-replicate normalisation (summing counts across motifs at each index) assumes.
nrgraphs = 50
num_graphs_random = {mot: [] for mot in triad_motifs}
graph_specs_random = {mot: {'conn_type': [], 'ntype': []} for mot in triad_motifs}

for g in range(nrgraphs):
    nnr = utils.randomize_graph(nn, seed=g)
    for mot in triad_motifs:
        # Kept keyed by motif name so the stored layout stays
        # graph_specs_random[mot][...][j][mot][...], as later cells expect.
        conn_types_random = {mot: {n: 0 for n in ntype_pairs}}
        ntype_motif_random = {mot: {m: {n: 0 for n in ntype} for m in triad_motifs[mot].nodes}}

        all_matches = nnr.search_motifs(triad_motifs[mot])
        num_graphs_random[mot].append(len(all_matches))

        for sub in all_matches:
            for motif_edge, network_edge in sub.items():
                # ignore edges touching a neuron whose class is not tracked
                if all(n.type in ntype for n in network_edge):
                    for m, n in zip(motif_edge, network_edge):
                        ntype_motif_random[mot][m][n.type] += 1
                    conn = (network_edge[0].type, network_edge[1].type)
                    conn_types_random[mot][conn] += 1

        # Fresh dicts are built on every iteration, so no defensive copy is needed.
        graph_specs_random[mot]['conn_type'].append(conn_types_random)
        graph_specs_random[mot]['ntype'].append(ntype_motif_random)

In [None]:


def fig2array(fig):
    """Rasterise a Matplotlib figure into an ``(height, width, 3)`` uint8 RGB array.

    NOTE(review): this repeats the definition from an earlier cell; keep the two
    in sync or delete one of them.

    ``buffer_rgba`` is used instead of ``tostring_rgb`` (deprecated and removed
    in recent Matplotlib), and the array shape comes from the rendered buffer
    itself rather than being recomputed from ``size_inches * dpi``.
    """
    canvas = FigureCanvas(fig)
    canvas.draw()
    rgba = np.asarray(canvas.buffer_rgba())
    # Drop the alpha channel; copy so the result does not alias the canvas buffer.
    return rgba[..., :3].copy()

# Draw each triad motif as a small stand-alone figure; they are rasterised below
# and placed under the x-axis of the count plot.
nodesize = 300
figs = []
pos = utils.nx.circular_layout(triad_motifs['300'])
for mot in triad_motifs:
    f, ax = plt.subplots(figsize=(1.25, 1.25), dpi=300)
    # The artists returned by the draw calls are not needed, so they are not bound.
    utils.nx.draw_networkx_nodes(triad_motifs[mot], node_size=nodesize, pos=pos, ax=ax, node_color='lightgray')
    utils.nx.draw_networkx_edges(triad_motifs[mot], node_size=nodesize, pos=pos, ax=ax, arrowstyle='->', width=2, arrowsize=10)
    utils.nx.draw_networkx_labels(triad_motifs[mot], pos=pos, ax=ax, font_size='xx-large')
    ax.set_xlim((-0.75, 1.25))
    ax.set_ylim((-1.15, 1.15))
    ax.set_frame_on(False)
    for spine in ax.spines.values():
        spine.set_visible(False)
    f.set_facecolor('none')
    figs.append(f)


# Motif counts in the real network (purple) against the randomized networks
# (gray dots: individual replicates, black: their mean), on a log scale.
f, ax = plt.subplots(figsize=(16,4), layout='constrained')
n_motifs = len(triad_motifs)

# Totals are computed once instead of once per motif inside the comprehensions.
total_random = np.sum([num_graphs_random[mot1] for mot1 in triad_motifs], axis=0)  # per replicate
total_actual = np.sum([num_graphs[mot1] for mot1 in triad_motifs])
frac_num_graphs_random = {mot: num_graphs_random[mot]/total_random for mot in triad_motifs}
frac_num_graphs = {mot: num_graphs[mot]/total_actual for mot in triad_motifs}

# z-score of the real motif fraction against the randomized fractions.
zscore_sig_single = [(frac_num_graphs[mot]- np.mean(frac_num_graphs_random[mot]))/np.std(frac_num_graphs_random[mot]) for mot in triad_motifs]
zscore_sig_thres = 2.58
zscore_sig_single_ind = np.where(np.abs(zscore_sig_single)>zscore_sig_thres)[0]

for j, mot in enumerate(triad_motifs):
    ax.scatter([j]*len(num_graphs_random[mot]), y=num_graphs_random[mot], color='gray', alpha=0.2)

ax.scatter(range(n_motifs), [num_graphs[mot] for mot in triad_motifs], color='purple')
ax.scatter(range(n_motifs), y=[np.mean(num_graphs_random[mot]) for mot in triad_motifs], color='k')

ax.set_yscale('log')
ax.set_ylim(ymin=1, ymax=10**8)

ax.set_xticks(range(n_motifs))
ax.set_xticklabels([mot for mot in triad_motifs], fontsize='xx-large')
ax.tick_params(axis='y', labelsize='xx-large')
ax.set_xlim((-0.5, n_motifs-0.5))

# Rasterise each thumbnail, pin it below its tick (axes-fraction coordinates), then
# close the thumbnail figure so it neither lingers in memory nor shows up as cell output.
for i, thumb in enumerate(figs):
    offset_image = OffsetImage(fig2array(thumb), zoom=0.2)
    annotation_box = AnnotationBbox(offset_image, ((i+0.5)/n_motifs, -0.4), frameon=False, xycoords='axes fraction')
    ax.add_artist(annotation_box)
    plt.close(thumb)

# Mark motifs whose |z| exceeds the threshold at the top of the axes.
for x in zscore_sig_single_ind:
    ax.text(x=x, y=10**8, s="**", color='k')

simpleaxis(ax)
f.supylabel("Number of matched networks", fontsize='xx-large')
f.savefig("Single_triads_50_random_networks.svg", format='svg', transparent=True, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# All ordered pairs of neuron classes (rebuilds the same list that was defined near the top).
ntype_pairs = [(m,n) for m in ntype for n in ntype]

In [None]:
# Display the pairs for inspection.
ntype_pairs

In [None]:
# NOTE(review): scratch cell — `mot` is whatever the last loop left behind (hidden state),
# and `gmr` is recomputed per motif inside the plotting loop further down.
gmr = np.array([[graph_specs_random[mot]['conn_type'][j][mot][conn] for j in range((nrgraphs))] for conn in conn_types[mot]])

In [None]:
# NOTE(review): scratch cell — actual_connt / rand_connt_mean / rand_connt_std are first
# assigned in the NEXT cell, so this raises NameError under Restart & Run All.
(np.array(actual_connt) - np.array(rand_connt_mean))/np.array(rand_connt_std)

In [None]:
# Per motif: fraction of matched edges of each connection type in the real network
# (purple) against the randomized networks (gray, mean ± s.d.).
zscore_sig_thres = 2.58
for mot in triad_motifs:
    conns = list(conn_types[mot])
    # Bar positions follow the number of connection types actually present,
    # instead of a hard-coded 6 that breaks when there are 9 ordered pairs.
    xpos = np.arange(len(conns))

    # 0/0 (motif never matched) legitimately yields NaN here; such motifs are skipped below.
    with np.errstate(divide='ignore', invalid='ignore'):
        # rows: connection types, columns: random replicates; normalised per replicate
        gmr = np.array([[graph_specs_random[mot]['conn_type'][j][mot][conn] for j in range(nrgraphs)] for conn in conns])
        gmr = gmr/np.sum(gmr, axis=0)
        gm = np.array([conn_types[mot][conn] for conn in conns])
        gm = gm/np.sum(gm)

    actual_connt = gm
    rand_connt_mean = np.mean(gmr, axis=1)
    rand_connt_std = np.std(gmr, axis=1)

    if np.isnan(actual_connt).any() or np.isnan(rand_connt_mean).any():
        continue

    f, ax = plt.subplots(figsize=(2,2), layout='constrained')
    ax.bar(xpos-0.25, actual_connt, color='purple', width=0.25)
    ax.bar(xpos, rand_connt_mean, color='gray', width=0.25)
    ax.errorbar(xpos, y=rand_connt_mean, yerr=rand_connt_std, color='gray', linestyle='none')
    with np.errstate(divide='ignore', invalid='ignore'):
        zind = (actual_connt - rand_connt_mean)/rand_connt_std
    # Only enrichment (positive z) above the threshold is starred.
    for x in np.where(zind>zscore_sig_thres)[0]:
        ax.text(x-0.25, ax.get_ylim()[1], s="*", color='k')
    ax.set_xticks(xpos, conns, rotation=45, ha='right')
    simpleaxis(ax)
    f.suptitle(mot)
    plt.show()
    plt.close(f)

In [None]:
# Same pie-per-node figure as for the real network, but pooled over the randomized networks.
piesize = 0.3
for mot in triad_motifs:
    pos = utils.nx.circular_layout(triad_motifs[mot])

    color_dict = {p: facecolors for p in pos.keys()}
    alpha_dict = {p: 1 for p in pos.keys()}

    # graph_specs_random[mot]['ntype'] is a list with one {mot: {node: {class: count}}}
    # dict per random replicate; pool the counts over replicates before normalising.
    replicates = graph_specs_random[mot]['ntype']
    pie_division = {}
    for m in triad_motifs[mot].nodes:
        pooled = {n: sum(rep[mot][m][n] for rep in replicates) for n in ntype}
        total = sum(pooled.values())
        pie_division[m] = [pooled[n]/total for n in ntype] if total else []

    if not all(pie_division.values()):
        # Nothing to draw (and therefore nothing to save) for this motif.
        print(f"No edges found for {mot}")
        continue

    f, ax = plt.subplots(figsize=(2,2), layout='constrained')
    utils.nx.draw_networkx_edges(triad_motifs[mot], pos=pos, ax=ax, node_size=1200, connectionstyle='arc3', arrowstyle='->')
    for n, p in pos.items():
        utils.plot_pie(n=n, center=p, ax=ax, color_dict=color_dict, alpha_dict=alpha_dict, pie_division=pie_division[n], piesize=piesize)
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_aspect('equal')
    ax.axis('off')
    f.suptitle(mot)
    # Distinct file name so the real-network figures written earlier are not overwritten;
    # saved before show() so the file is not blank.
    f.savefig(f'{utils.OUTPUT_DIR}/motif_ntype_division-{mot}_randomized.svg', transparent=True)
    plt.show()
    plt.close(f)

In [None]:
# Ratio of each randomized match count to the real-network match count, per motif.
single_motif_fractions = {}
for mot in triad_motifs:
    observed = np.array(num_graphs[mot])
    randomized = np.array(num_graphs_random[mot])
    single_motif_fractions[mot] = randomized/observed

In [None]:
# Base motif for the chained ("hypermotif") analysis below: the '030T' triad.
motif = triad_motifs['030T']

In [None]:
# Swap node labels 2 and 3 (node 1 keeps its label).
# NOTE(review): not idempotent — running this cell twice swaps the labels back.
motif = utils.nx.relabel_nodes(motif, {1:1, 2:3, 3:2})

In [None]:
# Count matches of chained motifs in the real network for chain lengths 2..max_chain_length.
# NOTE(review): `num_graphs` is rebound here, discarding the per-triad counts computed earlier.
num_graphs = {}
mappings = [(3,1)]
max_chain_length = 3
for mapping in mappings:
    num_graphs[mapping] = []
    for l in range(2,max_chain_length+1):
        # chained motif built from `l` copies of `motif` joined via `mapping`
        # (exact joining semantics live in utils.make_hypermotifs)
        hm = utils.make_hypermotifs(motif, l, [mapping])
        all_ffgs = nn.search_motifs(hm)
        num_graphs[mapping].append(len(all_ffgs))
# `hm` and `all_ffgs` from the last iteration are reused by the cells that follow.

In [None]:
# Number of matched sub-networks as a function of chain length.
colors = ['blue', 'gray']
# The counts were collected for chain lengths 2..max_chain_length, so plot them at
# those x positions (the previous arange(1, max_chain_length) was shifted by one).
chain_lengths = np.arange(2, max_chain_length+1)
f, ax = plt.subplots(figsize=(2.5,2.5), layout='constrained')
for m, mapping in enumerate(mappings):
    ax.scatter(chain_lengths, num_graphs[mapping], color=colors[m], label=mapping)
ax.set_xticks(np.arange(1,max_chain_length+1))
ax.set_yticks((0,5000,10000), ('0','5k','10k'))
ax.set_ylabel("# Networks")
ax.set_xlabel("# chained FFLs")
simpleaxis(ax)
f.legend(loc='upper right', ncols=1, bbox_to_anchor=(1.1, 1))
# plt.savefig('FFL-chains-C_elegans.svg')
plt.show()

In [None]:
# Tally neuron classes per motif node and per matched edge for the current `hm` / `all_ffgs`.
conn_types = dict.fromkeys(ntype_pairs, 0)
ntype_motif = {m: dict.fromkeys(ntype, 0) for m in hm.nodes}
for sub in all_ffgs:
    for motif_edge, network_edge in sub.items():
        # ignore edges touching a neuron whose class is not tracked
        if not all(n.type in ntype for n in network_edge):
            continue
        for m, n in zip(motif_edge, network_edge):
            ntype_motif[m][n.type] += 1
        conn = (network_edge[0].type, network_edge[1].type)
        conn_types[conn] += 1

In [None]:
# Draw the chained motif with a pie at every node showing the class composition
# of the neurons matched to that node.
piesize=0.08
pos = utils.nx.kamada_kawai_layout(hm)
color_dict = {p:facecolors for p in pos.keys()}
alpha_dict = {p:1 for p in pos.keys()}
# Per node: fraction of matches per class. NOTE(review): divides by zero if a node has no tallied matches.
pie_division = {m: [ntype_motif[m][n]/sum(ntype_motif[m].values()) for n in ntype] for m in hm.nodes}

f, ax = plt.subplots(figsize=(8,8))
# No ax= is passed, so the edges go onto the current axes (the one just created).
utils.nx.draw_networkx_edges(hm, pos=pos, node_size=1200, connectionstyle='arc3', arrowstyle='->')
for n,p in pos.items():
    utils.plot_pie(n=n, center=p, ax=ax, color_dict=color_dict, alpha_dict=alpha_dict, pie_division=pie_division[n], piesize=piesize)
# utils.nx.draw_networkx_labels(hm, pos=pos)
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_aspect('equal')
ax.axis('off')  # Optionally turn off axis
# plt.savefig('motif_ntype_division.svg')


In [None]:
# NOTE(review): scratch cell — `count_edgelabels` is only defined in the next cell,
# so this raises NameError under Restart & Run All.
count_edgelabels.items()

In [None]:
from cedne.utils import OUTPUT_DIR


# For every node of the chained motif, record which neuron fills it in each match,
# then plot the distribution of neuron names per node.
hseq = hm
mot_edgelabels = {node: [] for node in hseq}
for ffg in all_ffgs:
    # motif node -> name of the neuron occupying it in this match
    nodelist = dict.fromkeys(hseq)
    for med, ned in ffg.items():
        for m, n in zip(med, ned):
            nodelist[m] = n.name
    for node, neuron_name in nodelist.items():
        mot_edgelabels[node].append(neuron_name)

# motif node -> Counter of neuron names
count_edgelabels = {node: Counter(mot_edgelabels[node]) for node in hseq}

# Neurons whose combined share is printed for each node.
tracked_neurons = ('AVAL', 'AVAR', 'AVBL', 'AVBR', 'AVDL', 'AVDR', 'AVEL', 'AVER')

f, ax = plt.subplots(nrows=len(count_edgelabels), figsize=(40,2*len(count_edgelabels)), layout='constrained', sharey=True)
for n, (node, d) in enumerate(sorted(count_edgelabels.items(), key=lambda x:x[0])):
    # neuron names in alphabetical order with their counts
    name, height = zip(*sorted(d.items(), key=lambda x:x[0]))
    print(sorted(d.items(), key=lambda x:x[1], reverse=True))
    print(sum(d[neuron] for neuron in tracked_neurons)/np.sum(height))
    ax[n].bar(name, np.array(height)/np.sum(height), color='gray')
    ax[n].set_xticks(range(len(name)), name, rotation=45, fontsize='xx-large')
    ax[n].tick_params(axis='y', labelsize='xx-large')
    utils.simpleaxis(ax[n])
    ax[n].set_title(node, fontsize='xx-large')
f.supylabel("Fraction of matched neurons for node", fontsize='xx-large')
f.supxlabel("Neuron name", fontsize='xx-large')
plt.savefig(OUTPUT_DIR + "/HierarchicalSequence-neuronfractions.svg")
plt.show()

In [None]:
# Alignment score of the chained motif; at this point `conn_types` is the flat
# {(class, class): count} dict built for `hm` a few cells above.
hierarchical_alignment(conn_types)

In [None]:
# NOTE(review): debugging leftover — `motif_edge` is the loop variable leaked from an earlier cell.
motif_edge, hm.nodes

In [None]:
# Collect the distinct network edges / neurons that take part in any match of the
# chained motif, and record per motif node which neuron fills it in each match.
sequential_hierarchy_edges = set()
sequential_hierarchy_nodes = set()
node_id = {n: [] for n in hm.nodes}
for sub in all_ffgs:
    # motif node -> neuron name for this particular match
    local_node_id = dict.fromkeys(hm.nodes)
    for motif_edge, network_edge in sub.items():
        sequential_hierarchy_edges.add(network_edge)
        sequential_hierarchy_nodes.update(network_edge)
        for end in (0, 1):
            local_node_id[motif_edge[end]] = network_edge[end].name
    for n in hm.nodes:
        node_id[n].append(local_node_id[n])

In [None]:
# For each motif node, how often each neuron was matched to it.
for n in hm.nodes:
    print(n, Counter(node_id[n]))

In [None]:
# Per connection type: what fraction of the network's connections takes part in
# at least one match of the chained motif.
seq_hier_count = {ntyp: 0 for ntyp in ntype_pairs}
nn_count = {ntyp: 0 for ntyp in ntype_pairs}
for e, _conn in nn.connections.items():
    if e[0].type in ntype and e[1].type in ntype:
        pair = (e[0].type, e[1].type)
        nn_count[pair] += 1
        if (e[0], e[1]) in sequential_hierarchy_edges:
            seq_hier_count[pair] += 1

for nty in ntype_pairs:
    # A pair with no connections in the network has no defined fraction; print NaN
    # instead of raising ZeroDivisionError.
    fraction = seq_hier_count[nty]/nn_count[nty] if nn_count[nty] else float('nan')
    print(nty, fraction)

In [None]:
# Chained-motif census on the real network for two ways of joining the motif
# copies and chain lengths 1..max_chain_length.
mappings = [(2,1), (3,1)]
max_chain_length = 6
num_graphs = {}     # mapping -> [match count for length 1, 2, ...]
conn_types = {}     # mapping -> length -> {(class, class): count}
ntype_motif = {}    # mapping -> length -> motif node -> {class: count}
for mapping in mappings:
    num_graphs[mapping] = []
    ntype_motif[mapping] = {}
    conn_types[mapping] = {}
    for l in range(1, max_chain_length+1):
        conn_types[mapping][l] = dict.fromkeys(ntype_pairs, 0)
        hm = utils.make_hypermotifs(motif, l, [mapping])
        ntype_motif[mapping][l] = {}
        for m in hm.nodes:
            ntype_motif[mapping][l][m] = dict.fromkeys(ntype, 0)
        all_ffgs = nn.search_motifs(hm)
        num_graphs[mapping].append(len(all_ffgs))
        for sub in all_ffgs:
            for motif_edge, network_edge in sub.items():
                # ignore edges touching a neuron whose class is not tracked
                if not all(n.type in ntype for n in network_edge):
                    continue
                for m, n in zip(motif_edge, network_edge):
                    ntype_motif[mapping][l][m][n.type] += 1
                conn = (network_edge[0].type, network_edge[1].type)
                conn_types[mapping][l][conn] += 1

In [None]:
# Number of matched sub-networks against chain length, one colour per mapping.
colors = ['gray', 'purple']
f, ax = plt.subplots(figsize=(2.5,2.5), layout='constrained')
for m, mapping in enumerate(mappings):
    ax.scatter(np.arange(1,max_chain_length+1), num_graphs[mapping], color=colors[m], label=mapping)
    # ax.scatter(np.arange(1,len(num_graphs[mapping])+1), num_graphs[mapping], color='blue')
ax.set_xticks(np.arange(1,max_chain_length+1))
# Fixed y ticks with abbreviated labels.
ax.set_yticks((0,5000,10000), ('0','5k','10k'))
# ax.yaxis.major.formatter._useMathText = True
ax.set_ylabel("# subnetworks")
ax.set_xlabel("# chained FFLs")
simpleaxis(ax)
f.legend(loc='upper right', ncols=1, bbox_to_anchor=(1.1, 1), frameon=False)
# Written to the current working directory (not utils.OUTPUT_DIR).
plt.savefig('FFL-chains-lengths-C_elegans.svg', transparent=True)
plt.show()

In [None]:
# Hierarchical alignment for every mapping and chain length.
for mapping in mappings:
    for l in range(1, max_chain_length+1):
        print(mapping, l, hierarchical_alignment(conn_types[mapping][l]))

In [None]:
# Share of matches contributed by each mapping at every chain length (real network).
# The denominator sums over ALL mappings, so the cell keeps working if `mappings`
# gains or loses entries; it is also computed once rather than per mapping.
total_matches = np.sum([np.array(num_graphs[mapping]) for mapping in mappings], axis=0)
motif_fractions_actual = {}
for mapping in mappings:
    motif_fractions_actual[mapping] = np.array(num_graphs[mapping])/total_matches

colors = ['gray', 'purple']
f, ax = plt.subplots(figsize=(2.5,2.5), layout='constrained')
for m, mapping in enumerate(mappings):
    ax.scatter(np.arange(1,len(num_graphs[mapping])+1), motif_fractions_actual[mapping], color=colors[m], label=mapping)
ax.set_xticks(np.arange(1,max_chain_length+1))
ax.set_yticks((0,0.5,1))
ax.set_ylabel("fraction subnetworks with motif")
ax.set_xlabel("Number of chained FFLs")
simpleaxis(ax)
f.legend(loc='upper center', frameon=False, ncols=2, bbox_to_anchor=(0.55,1.1))
plt.savefig('FFL-chains-C_elegans_fractions.svg', transparent=True)
plt.show()

In [None]:
# NOTE(review): scratch calculation (floor(ln(#edges)) * #edges); the result is not used by any cell shown here.
int(np.log(len(nn.edges)))*len(nn.edges)

In [None]:
# Pie-per-node figure of the chained motif (real network), one figure per mapping.
piesize = 0.08
chain_length = 3
for mapping in mappings:
    hm = utils.make_hypermotifs(motif, chain_length, [mapping])
    pos = utils.nx.kamada_kawai_layout(hm)
    color_dict = {p: facecolors for p in pos.keys()}
    alpha_dict = {p: 1 for p in pos.keys()}
    # ntype_motif is nested mapping -> chain length -> node -> class; the chain-length
    # level was missing from the lookup before.
    node_counts = ntype_motif[mapping][chain_length]
    pie_division = {m: [node_counts[m][n]/sum(node_counts[m].values()) for n in ntype] for m in hm.nodes}

    f, ax = plt.subplots(figsize=(8,8))
    utils.nx.draw_networkx_edges(hm, pos=pos, ax=ax, node_size=1200, connectionstyle='arc3', arrowstyle='->')
    for n, p in pos.items():
        utils.plot_pie(n=n, center=p, ax=ax, color_dict=color_dict, alpha_dict=alpha_dict, pie_division=pie_division[n], piesize=piesize)
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_aspect('equal')
    ax.axis('off')
    f.savefig(f'motif_ntype_division_{mapping}.svg', transparent=True)
    plt.show()
    plt.close(f)


## Randomizing graphs by swapping edges.

In [None]:
# Chained-motif census on randomized networks.
# NOTE(review): `num_graphs` is rebound here to the RANDOMIZED counts
# (mapping -> chain length -> list with one count per random replicate).
num_graphs = {}
mappings = [(2,1), (3,1)]
max_chain_length = 6
nrgraphs = 50
graph_specs = {}

# Generate the randomized networks once, with fixed seeds, and reuse them for
# every mapping and chain length. Replicate g is then the same graph everywhere,
# which the replicate-wise fractions computed later (counts of one mapping divided
# by the sum over mappings, index by index) rely on. It also makes the cell reproducible.
random_networks = [utils.randomize_graph(nn, seed=g) for g in range(nrgraphs)]

for mapping in mappings:
    num_graphs[mapping] = {}
    graph_specs[mapping] = {}
    for l in range(1, max_chain_length+1):
        num_graphs[mapping][l] = []
        graph_specs[mapping][l] = {'conn_type': [], 'ntype': []}
        hm = utils.make_hypermotifs(motif, l, [mapping])
        for nnr in random_networks:
            all_ffgs = nnr.search_motifs(hm)
            num_graphs[mapping][l].append(len(all_ffgs))

            conn_types = {n: 0 for n in ntype_pairs}
            ntype_motif = {m: {n: 0 for n in ntype} for m in hm.nodes}
            for sub in all_ffgs:
                for motif_edge, network_edge in sub.items():
                    if all(n.type in ntype for n in network_edge):
                        for m, n in zip(motif_edge, network_edge):
                            ntype_motif[m][n.type] += 1
                        # Keep the edge direction (no sorting) so these tallies use the
                        # same keys as the real-network tallies.
                        conn = (network_edge[0].type, network_edge[1].type)
                        conn_types[conn] += 1
            # Fresh dicts are built for every replicate, so no defensive copy is needed.
            graph_specs[mapping][l]['conn_type'].append(conn_types)
            graph_specs[mapping][l]['ntype'].append(ntype_motif)

In [None]:
# Mean ± s.d. (over random replicates) of the match count at each chain length, per mapping.
colors = ['gray', 'purple']
f, ax = plt.subplots(figsize=(2.5,2.5), layout='constrained')
for m, mapping in enumerate(mappings):
    # rows: chain lengths, columns: replicates -> reduce over replicates (axis=1)
    ngraphs_mu = np.mean([num_graphs[mapping][l] for l in np.arange(1,max_chain_length+1)], axis=1)
    ngraphs_sigma = np.std([num_graphs[mapping][l] for l in np.arange(1,max_chain_length+1)], axis=1)
    ax.errorbar(np.arange(1,max_chain_length+1), ngraphs_mu, yerr = ngraphs_sigma, color=colors[m])
    ax.scatter(np.arange(1,len(num_graphs[mapping])+1), ngraphs_mu, color=colors[m], label=mapping)
ax.set_xticks(np.arange(1,max_chain_length+1))
ax.set_yticks((0,15000,30000), ('0','15k','30k'))
# ax.yaxis.major.formatter._useMathText = True
ax.set_ylabel("# subnetworks")
ax.set_xlabel("# chained FFLs")
simpleaxis(ax)
f.legend(loc='upper right', ncols=1, bbox_to_anchor=(1.1, 1), frameon=False)
# Written to the current working directory (not utils.OUTPUT_DIR).
plt.savefig('FFL-chains-C_elegans_randomized_edges.svg', transparent=True)
plt.show()

In [None]:
# Per random replicate: share of matches contributed by each mapping, at every chain length.
first_mapping, second_mapping = mappings[0], mappings[1]
motif_fractions = {}
for mapping in mappings:
    per_length = {}
    for l in np.arange(1, max_chain_length+1):
        both = np.array(num_graphs[first_mapping][l]) + np.array(num_graphs[second_mapping][l])
        per_length[l] = np.array(num_graphs[mapping][l])/both
    motif_fractions[mapping] = per_length

In [None]:
# NOTE(review): debugging leftover — inspects the array left behind by the plotting loop above.
ngraphs_mu.shape

In [None]:
# Real-network motif share (solid) against the randomized share (dashed, mean ± s.d.),
# with "**" where the real value is more than `zscore_sig_thres` s.d. away.
colors = ['gray', 'purple']
zscore_sig_thres = 2.58
chain_lengths = np.arange(1, max_chain_length+1)
f, ax = plt.subplots(figsize=(2.5,2.5), layout='constrained')
for m, mapping in enumerate(mappings):
    # rows: chain lengths, columns: random replicates
    random_fractions = np.array([motif_fractions[mapping][l] for l in chain_lengths])
    ngraphs_mu = np.mean(random_fractions, axis=1)
    ngraphs_sigma = np.std(random_fractions, axis=1)

    ax.errorbar(chain_lengths, ngraphs_mu, yerr=ngraphs_sigma, color=colors[m], alpha=0.5, linestyle='--')
    ax.scatter(np.arange(1,len(num_graphs[mapping])+1), motif_fractions_actual[mapping], color=colors[m], label=mapping)
    ax.plot(np.arange(1,len(num_graphs[mapping])+1), motif_fractions_actual[mapping], color=colors[m])

    # A zero or NaN s.d. yields inf/NaN z-scores; compute them without runtime warnings.
    # NaN never exceeds the threshold, so such points are simply not starred.
    with np.errstate(divide='ignore', invalid='ignore'):
        zscores = (np.asarray(motif_fractions_actual[mapping]) - ngraphs_mu)/ngraphs_sigma
    zscore_sig = np.where(np.abs(zscores)>zscore_sig_thres)[0]
    for x in zscore_sig:
        # index 0 corresponds to chain length 1
        ax.text(x=x+1, y=0.8, s="**", color='k')

ax.set_xticks(chain_lengths)
ax.set_yticks((0,0.5,1))
ax.set_ylabel("Fraction subnetworks with motif")
ax.set_xlabel("# chained FFLs")
simpleaxis(ax)
f.legend(loc='upper left', frameon=False, bbox_to_anchor=(0.15,1.05))
plt.savefig('FFL-chains-C_elegans_randomized_actual_edges_fractions_new.svg', transparent=True)
plt.show()

In [None]:
# networkx `constraint` evaluated for every node of the network (dict: node -> value).
constraint = utils.nx.constraint(nn)

In [None]:
# Neurons listed from lowest to highest constraint value.
for nname, cons in sorted(constraint.items(), key=lambda x:x[1]):
    print(nname.name, cons)

In [None]:
# For each mapping and chain length: number of random replicates whose motif share
# lies below the share observed in the real network.
for mapping in mappings:
    random_graph = [motif_fractions[mapping][l] for l in np.arange(1,max_chain_length+1)]
    for l in np.arange(1,max_chain_length+1):
        below_actual = motif_fractions_actual[mapping][l-1] > random_graph[l-1]
        print(l, np.count_nonzero(below_actual))

In [None]:
# Pie-per-node figure of the chained motif, averaged over the randomized networks.
piesize = 0.08
chain_length = 3
for mapping in mappings:
    print(mapping)
    # Use `chain_length` for the motif as well, so it always matches the tallies looked up below.
    hm = utils.make_hypermotifs(motif, chain_length, [mapping])
    pos = utils.nx.kamada_kawai_layout(hm)
    color_dict = {p: facecolors for p in pos.keys()}
    alpha_dict = {p: 1 for p in pos.keys()}

    # One {node: {class: count}} dict per random replicate.
    replicates = graph_specs[mapping][chain_length]['ntype']
    pie_division = {}
    for node in hm.nodes:
        mean_counts = np.array([np.mean([rep[node][n] for rep in replicates]) for n in ntype])
        total = mean_counts.sum()
        # Normalise to fractions, as every other plot_pie call in this notebook does;
        # left untouched when there is nothing to normalise.
        pie_division[node] = list(mean_counts/total) if total else list(mean_counts)

    f, ax = plt.subplots(figsize=(8,8))
    utils.nx.draw_networkx_edges(hm, pos=pos, ax=ax, node_size=1200, connectionstyle='arc3', arrowstyle='->')
    for n, p in pos.items():
        utils.plot_pie(n=n, center=p, ax=ax, color_dict=color_dict, alpha_dict=alpha_dict, pie_division=pie_division[n], piesize=piesize)
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_aspect('equal')
    ax.axis('off')
    f.savefig(f'motif_ntype_division_{mapping}_randomized_ntypes.svg', transparent=True)
    plt.show()
    plt.close(f)

In [None]:
# Count edges of the randomized network `nn2` that also exist in the real network `nn`.
# NOTE(review): `nn2` is only created in a LATER cell, so this raises NameError under
# Restart & Run All; move it below the randomization cells or delete it.
pres = []
for e in nn2.edges:
    # translate the edge to nn's own neuron objects (looked up by name); e[2] is carried over unchanged
    e2 = (nn.neurons[e[0].name], nn.neurons[e[1].name], e[2])
    if(e2 in nn.edges):
        pres.append((e[0].name, e[1].name))
print(len(pres))

In [None]:
# networkx triadic census of the real network, reordered to follow triad_motifs' key order.
actual = utils.nx.triadic_census(nn)
num_g_act = [actual[t] for t in triad_motifs.keys()]

## Randomizing by swapping edges

In [None]:
# Triadic census of `num_rands` randomized networks (default randomization mode), seeded 0..num_rands-1.
num_rands = 100
num_g_ran_all = []  # one row per randomized network, columns in triad_motifs key order
for ni in range(num_rands):
    nn2 = utils.randomize_graph(nn, seed=ni)
    ran = utils.nx.triadic_census(nn2)
    num_g_ran = [ran[t] for t in triad_motifs.keys()]
    num_g_ran_all.append(num_g_ran)

In [None]:
# Observed triad counts (orange) against the randomized mean ± s.d. (gray), log scale.
motif_names = list(triad_motifs.keys())
random_mean = np.mean(num_g_ran_all, axis=0)
random_std = np.std(num_g_ran_all, axis=0)

f, ax = plt.subplots(figsize=(12,3))
ax.scatter(motif_names, num_g_act, color='orange')
ax.scatter(motif_names, random_mean, color='gray')
ax.errorbar(motif_names, random_mean, yerr=random_std, color='gray', linestyle='None')
ax.set_yscale('log')
plt.show()

## Randomizing by degree sequence

In [None]:
# Triadic census of `num_rands` randomized networks using mode='configuration-model', seeded 0..num_rands-1.
num_rands = 100
num_g_ran_all = []  # one row per randomized network, columns in triad_motifs key order
for ni in range(num_rands):
    nn2 = utils.randomize_graph(nn, seed=ni, mode='configuration-model')
    ran = utils.nx.triadic_census(nn2)
    num_g_ran = [ran[t] for t in triad_motifs.keys()]
    num_g_ran_all.append(num_g_ran)

In [None]:
# Observed triad counts (orange) against the randomized mean ± s.d. (gray), log scale.
motif_names = list(triad_motifs.keys())
random_mean = np.mean(num_g_ran_all, axis=0)
random_std = np.std(num_g_ran_all, axis=0)

f, ax = plt.subplots(figsize=(12,3))
ax.scatter(motif_names, num_g_act, color='orange')
ax.scatter(motif_names, random_mean, color='gray')
ax.errorbar(motif_names, random_mean, yerr=random_std, color='gray', linestyle='None')
ax.set_yscale('log')
plt.show()

## Randomizing by number of nodes and edges.

In [None]:
# Triadic census of `num_rands` randomized networks using mode='num-nodes-edges', seeded 0..num_rands-1.
num_rands = 100
num_g_ran_all = []  # one row per randomized network, columns in triad_motifs key order
for ni in range(num_rands):
    nn2 = utils.randomize_graph(nn, seed=ni, mode='num-nodes-edges')
    ran = utils.nx.triadic_census(nn2)
    num_g_ran = [ran[t] for t in triad_motifs.keys()]
    num_g_ran_all.append(num_g_ran)

In [None]:
# Observed triad counts (orange) against the randomized mean ± s.d. (gray), log scale.
motif_names = list(triad_motifs.keys())
random_mean = np.mean(num_g_ran_all, axis=0)
random_std = np.std(num_g_ran_all, axis=0)

f, ax = plt.subplots(figsize=(12,3))
ax.scatter(motif_names, num_g_act, color='orange')
ax.scatter(motif_names, random_mean, color='gray')
ax.errorbar(motif_names, random_mean, yerr=random_std, color='gray', linestyle='None')
ax.set_yscale('log')
plt.show()

In [None]:
# From here on, connection types are UNDIRECTED: each pair of classes is stored as an
# alphabetically sorted tuple.
ntype = ['sensory', 'interneuron', 'motorneuron']
# sorted(...) turns the set into a list with a fixed order. Iterating a bare set of
# string tuples can give a different order in every Python session (string hash
# randomization), which made the pair -> colour assignment below irreproducible.
ntype_pairs = sorted({tuple(sorted([nt1, nt2])) for nt1 in ntype for nt2 in ntype})
colors = plt.cm.magma(np.linspace(0, 1, len(ntype_pairs)))
type_color_dict = {p: color for (p, color) in zip(ntype_pairs, colors)}

In [None]:
# Build a hypermotif from `motif` and collect its matches in the network.
# NOTE(review): the meaning of the `1` and `[(1,1)]` arguments is defined by
# utils.make_hypermotifs and is not visible here - confirm against its docs.
# Both `hm` and `all_ffgs` are reassigned by the cells that follow.
hm = utils.make_hypermotifs(motif, 1, [(1,1)])
all_ffgs = nn.search_motifs(hm)

In [None]:
# Build a hypermotif from `motif` with arguments 3 and [(3,1)].
# NOTE(review): presumably three copies of the motif chained together - confirm
# against utils.make_hypermotifs.
hm = utils.make_hypermotifs(motif, 3, [(3,1)])
# Give the two nodes with composite labels the shorter names '2.1' and '3.1'.
hm = utils.nx.relabel_nodes(hm, {'1.3-2.1':'2.1', '2.3-3.1':'3.1'})

In [None]:
# Search the network for matches of the hypermotif. Downstream cells iterate
# each match as a mapping of motif edge -> network edge.
all_ffgs = nn.search_motifs(hm)

In [None]:
# Tallies over all motif matches:
#   conn_types  - number of matched edges per unordered neuron-type pair
#   ntype_motif - for each motif node, how often each neuron type occupies it
conn_types = dict.fromkeys(ntype_pairs, 0)
ntype_motif = {motif_node: dict.fromkeys(ntype, 0) for motif_node in hm.nodes}
for sub in all_ffgs:
    for motif_edge, network_edge in sub.items():
        # Ignore edges that touch a neuron whose type is not one of `ntype`.
        if not all(neuron.type in ntype for neuron in network_edge):
            continue
        for motif_node, neuron in zip(motif_edge, network_edge):
            ntype_motif[motif_node][neuron.type] += 1
        first_type, second_type = network_edge[0].type, network_edge[1].type
        # Sorted tuple so the key matches the unordered pairs in conn_types.
        conn_types[tuple(sorted([first_type, second_type]))] += 1

In [None]:
# For every motif node, convert the per-type tallies into fractions (listed in
# the order of `ntype`).
pie_division = {}
for motif_node in hm.nodes:
    type_counts = ntype_motif[motif_node]
    pie_division[motif_node] = [type_counts[t]/sum(type_counts.values()) for t in ntype]

In [None]:
# Draw the hypermotif with each node rendered as a pie chart of the neuron
# types that occupy it, then save the figure as SVG.
piesize = 0.09
pos = utils.nx.kamada_kawai_layout(hm)
# Every node uses the same colour list and full opacity.
color_dict = dict.fromkeys(pos, facecolors)
alpha_dict = dict.fromkeys(pos, 1)

f, ax = plt.subplots(figsize=(6,6), layout='constrained')
utils.nx.draw_networkx_edges(hm, pos=pos, node_size=1200, connectionstyle='arc3', arrowstyle='->')
for node, center in pos.items():
    utils.plot_pie(
        n=node,
        center=center,
        ax=ax,
        color_dict=color_dict,
        alpha_dict=alpha_dict,
        pie_division=pie_division[node],
        piesize=piesize,
    )
# utils.nx.draw_networkx_labels(hm, pos=pos)
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_aspect('equal')
ax.axis('off')  # hide the axes frame and ticks
plt.savefig('motif_ntype_ffchain.svg', transparent=True)



In [None]:
# Load neurotransmitter data onto the network.
# NOTE(review): presumably this populates the per-connection
# `putative_neurotrasmitter_receptors` attribute read in the next cell - confirm.
utils.loadNeurotransmitters(nn)

In [None]:
# Per motif edge, gather over all motif matches:
#   motif_conns   - the chemical-synapse connections realising that edge
#   conn_ligs     - one list of 'x-y' pair labels per such connection
#   putative_lens - the length of each of those lists
# all_ntrs is the flat list of every label seen.
conn_ligs, putative_lens, motif_conns = {}, {}, {}
all_ntrs = []
for sub in all_ffgs:
    for motif_edge, network_edge in sub.items():
        edge_ligs = conn_ligs.setdefault(motif_edge, [])
        edge_lens = putative_lens.setdefault(motif_edge, [])
        edge_conns = motif_conns.setdefault(motif_edge, [])
        neuron_1, neuron_2 = network_edge[0], network_edge[1]
        conns = nn.connections_between(neuron_1, neuron_2, directed=True)
        for conn_key, conn in conns.items():
            if conn.connection_type != 'chemical-synapse':
                continue
            edge_conns.append(conn)
            # Keep only pairs whose two leading entries are both strings.
            pair_labels = [
                '-'.join(pair)
                for pair in conn.putative_neurotrasmitter_receptors
                if isinstance(pair[0], str) and isinstance(pair[1], str)
            ]
            edge_lens.append(len(pair_labels))
            edge_ligs.append(pair_labels)
            all_ntrs.extend(pair_labels)

In [None]:
# One bar panel per motif edge: for every predicted pair label, the average
# over that edge's connections of 1/(number of pairs predicted for the
# connection) when the label is among them, else 0.
# Bar colours are looked up by the first three characters of each label.
color_nt = {'Ach': 'lightgreen', 'Dop':'navy', 'Glu': 'darkorange', 'GAB': 'crimson', 'Ser': 'k'}
all_labels = sorted(set(all_ntrs))
color = [color_nt[label[:3]] for label in all_labels]
n_panels = len(conn_ligs.keys())
fig, ax = plt.subplots(nrows=n_panels, sharex=True, sharey=True, figsize=(30,12), layout='constrained')
for j, e in enumerate(sorted(conn_ligs.keys())):
    per_conn_pairs = conn_ligs[e]
    panel = ax[j]
    nums = []
    for k in all_labels:
        weighted_hits = sum([1/len(pairs) if k in pairs else 0 for pairs in per_conn_pairs])
        nums.append(weighted_hits/len(per_conn_pairs))
    panel.bar(all_labels, nums, color=color)
    simpleaxis(panel)
    panel.set_title(e, fontsize="xx-large")
    panel.set_yticks([0,0.1,0.2], labels=[0,0.1,0.2], fontsize="xx-large")
plt.xticks(rotation=45, ha='right', fontsize="xx-large")
fig.supylabel("Fraction of edges with predicted pair", fontsize='xx-large')
fig.supxlabel("Predicted neurotransmitter-receptor pair", fontsize='xx-large')
plt.savefig("Motif-FFLoop-3-chain.svg", transparent=True)
plt.show()
plt.close()
    # print(sum([1 if k in conn_ligs[e][i] else 0 for i in range(len(conn_ligs[e]))])/len(conn_ligs[e]))
    # conn_ligs[e])
# conn_ligs[motif_edge]

In [None]:
# Histogram, per motif edge, of how many pair labels each connection carries.
n_edges = len(conn_ligs)
f, ax = plt.subplots(ncols=n_edges, figsize=(n_edges*2, 2), sharex=True, sharey=True, layout='constrained')
for col, motif_edge in enumerate(conn_ligs):
    panel = ax[col]
    panel.hist(putative_lens[motif_edge], bins=np.arange(0,18,1), color='gray')
    simpleaxis(panel)
    panel.set_title(motif_edge)
plt.show()

In [None]:
# One bar panel per motif edge: the share of each pair label among all pair
# labels collected for that edge.
all_labels = sorted(set(all_ntrs))
fig, ax = plt.subplots(nrows=len(conn_ligs.keys()), sharex=True, sharey=True, figsize=(21,9), layout='constrained')
for j,edge in enumerate(conn_ligs.keys()):
    # conn_ligs[edge] holds one *list* of labels per connection. Lists are
    # unhashable, so the labels have to be flattened before counting.
    c = Counter(label for pairs in conn_ligs[edge] for label in pairs)
    # putative_lens[edge] is the list of per-connection label counts; its sum
    # is the total number of labels on this edge and serves as the denominator.
    total = sum(putative_lens[edge])
    # Counter returns 0 for labels it has not seen; guard the empty-edge case.
    nums = [c[k]/total if total else 0 for k in all_labels]
    ax[j].bar(all_labels, nums, color='gray')
    simpleaxis(ax[j])
    ax[j].set_title(edge)
plt.xticks(rotation=45, ha='right')
# plt.savefig("Motif-FFChain.svg")
plt.show()
plt.close()

In [None]:
# One layered plot per motif edge showing the (source name, target name) pairs
# of its connections; int_conns accumulates the pairs across all motif edges.
int_conns = []
for mconn, edge_conns in motif_conns.items():
    interesting_conns = [(c._id[0].name, c._id[1].name) for c in edge_conns]
    utils.plot_layered(interesting_conns, nn, nodeColors={}, edgeColors='gray', title=mconn)
    int_conns.extend(interesting_conns)

In [None]:
# Flat list of (source name, target name, third id component) for every
# connection collected in motif_conns, in motif-edge order.
all_conns = [
    (c._id[0].name, c._id[1].name, c._id[2])
    for mconn in motif_conns.keys()
    for c in motif_conns[mconn]
]

In [None]:
# Distinct connections collected from the motif matches, as a fraction of all
# connections in the network.
len(set(all_conns)) / len(nn.connections.items())

In [None]:
# Raw counts behind the fraction: (connections in the network, distinct
# connections collected from the motif matches).
len(nn.connections.items()), len(set(all_conns))

In [None]:
# Group the keys of nn.neurons by each neuron's `category` attribute.
by_category = {}
for neuron_key in nn.neurons:
    category = nn.neurons[neuron_key].category
    by_category.setdefault(category, []).append(neuron_key)

In [None]:
# Fold the network using the category -> neuron-keys grouping.
# NOTE(review): presumably yields one node per category; the effect of
# data='clean' is defined by fold_network - confirm.
nn_cat = nn.fold_network(by_category, data='clean')

In [None]:
# List every edge of the folded network, one per line, for inspection.
for e in nn_cat.edges:
    print(e)

In [None]:
# For each motif edge, draw the folded network as a shell plot with the folded
# edges that correspond to that motif edge's connections in black and all
# others in light gray. int_conns accumulates the highlighted edge ids.
int_conns = []
for mconn, edge_conns in motif_conns.items():
    # Map each connection to a folded-network edge id: (source node, target node, key 0).
    interesting_conns = [
        (nn_cat.neurons[c._id[0].category], nn_cat.neurons[c._id[1].category], 0)
        for c in edge_conns
    ]
    edge_color_dict = {
        eid: ('k' if eid in interesting_conns else 'lightgray')
        for eid in nn_cat.edges
    }
    # Empty inner shell, every folded node on the outer shell.
    shells = [[], [nn_cat.neurons[n] for n in nn_cat.neurons]]
    fig_ent = utils.plot_shell(nn_cat, shells=shells, figsize=(8,8),  width_logbase=2, edge_color_dict=edge_color_dict)
    int_conns.extend(interesting_conns)

In [None]:
# Collect the endpoints of every distinct highlighted edge.
nodelist = []
for s in set(int_conns):
    # Each entry is (source, target, key). Only the first two elements are
    # nodes; unpacking the whole tuple also added the edge key 0 to the list
    # and inflated the unique-node count by one.
    nodelist += s[:2]

In [None]:
# Number of distinct entries in nodelist.
len(set(nodelist))

In [None]:
# Triad census of the original (non-randomized) network, for reference.
utils.nx.triadic_census(nn)

In [None]:
# Sanity check: are the motif's nodes labelled exactly 1..N (N = node count)?
sorted(motif.nodes) == [*range(1,len(motif.nodes)+1)]

In [None]:
# Show both sides of the comparison above: actual node labels vs. expected 1..N.
sorted(motif.nodes), [*range(1,len(motif.nodes)+1)]

In [None]:
# Inspect which class `motif` is an instance of.
type(motif)