In [None]:
from cedne import utils
import matplotlib.pyplot as plt
from collections import Counter
import numpy as np
import copy
import os

from cedne import utils
import os
import json
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
import matplotlib.cm as cm
import tensorly as tl
from tensorly.decomposition import parafac
from tensorly.decomposition import tucker
from tensorly.decomposition import partial_tucker
from tensorly.tenalg import multi_mode_dot
from tensorly import kruskal_to_tensor
from sklearn.preprocessing import StandardScaler
from scipy.ndimage import gaussian_filter1d

In [None]:
# Ensure the shared output directory exists (no-op when it is already there).
os.makedirs(utils.OUTPUT_DIR, exist_ok=True)

In [None]:
# Neuron classes used throughout the notebook; facecolors presumably parallels `ntype` (one colour per class).
ntype = ['sensory', 'interneuron', 'motorneuron']
facecolors = ['#FF6F61', '#FFD700', '#4682B4']
# Every unordered pair of classes (including same-class pairs) gets its own colour from the magma colormap.
ntype_pairs = set([tuple(sorted([nt1, nt2])) for nt1 in ntype for nt2 in ntype])
colors = plt.cm.magma(np.linspace(0, 1, len(ntype_pairs)))
# Iterate the pairs in sorted order: the iteration order of a set of string tuples depends on the
# per-process hash seed, so zipping the raw set assigned different colours in each kernel session.
type_color_dict = {p: color for (p, color) in zip(sorted(ntype_pairs), colors)}

In [None]:
# Build three variants of the worm and keep each one's "Neutral" network:
# chemical synapses only, the default makeWorm() (presumably both synapse types), and gap junctions only.
w = utils.makeWorm(chem_only=True)
nn_chem = w.networks["Neutral"]

w_both = utils.makeWorm()
nn_both = w_both.networks["Neutral"]

w_gapjn = utils.makeWorm(gapjn_only=True)
# Must come from w_gapjn: reading `w` here made nn_gapjn silently hold the chemical-only network.
nn_gapjn = w_gapjn.networks["Neutral"]

In [None]:
triad_motifs = utils.return_triads()
# Take the '030T' (transitive) triad and swap the labels of nodes 2 and 3.
motif = utils.nx.relabel_nodes(triad_motifs['030T'], {1: 1, 2: 3, 3: 2})

In [None]:
# Chain three copies of the motif (join spec [(3, 1)]) and give the merged nodes short names.
# The merged names ('1.3-2.1' -> '2.1', ...) suggest node 3 of one copy is identified with node 1 of the next.
# Then search the combined chemical + gap-junction network for instances.
hm = utils.nx.relabel_nodes(
    utils.make_hypermotifs(motif, 3, [(3, 1)]),
    {'1.3-2.1': '2.1', '2.3-3.1': '3.1'},
)
all_ffgs = nn_both.search_motifs(hm)

In [None]:
# Hypermotif edges in a fixed, sorted order, reused when walking motif instances below.
edges = list(hm.edges)
edges.sort()

## Checking whether gap-junction partners differ across nodes of the sequential hierarchy

In [None]:
# For each hypermotif node, gather the names of the outgoing gap-junction partners (in nn_gapjn)
# of every neuron that occupies that node in any motif instance.
gapjn_by_node = {n: set() for n in hm.nodes}
for ffg in all_ffgs:
    for edge in edges:
        # Visit the source end (0) and then the target end (1) of this hypermotif edge.
        for end in (0, 1):
            member = nn_gapjn.neurons[ffg[edge][end].name]
            partners = (conn[1].name for conn in member.get_connections(direction='out'))
            gapjn_by_node[edge[end]].update(partners)

In [None]:
# Tally, per hypermotif node, how many gap-junction partners belong to each neuron type.
by_type = {}
for hm_node in hm.nodes:
    by_type[hm_node] = dict.fromkeys(ntype, 0)
for hm_node, partners in sorted(gapjn_by_node.items()):
    tally = by_type[hm_node]
    for partner in partners:
        tally[nn_gapjn.neurons[partner].type] += 1

In [None]:
# Display the per-node counts of gap-junction partners by neuron type.
by_type

## Adding time-series information (neural activity traces from Atanas et al., 2023)

In [None]:
# Collect every connection that takes part in at least one hypermotif instance, as (pre, post, 0)
# triples. The trailing 0 is presumably the multigraph edge key expected by subnetwork() — TODO confirm.
all_edges = []
for ffg in all_ffgs:
    for edge in ffg:
        all_edges.append((ffg[edge][0], ffg[edge][1], 0))
# De-duplicate while keeping first-seen order. list(set(...)) ordered the triples by hash,
# which can differ between kernel sessions and made the subnetwork construction order non-reproducible.
all_edges = list(dict.fromkeys(all_edges))

In [None]:
# Restrict the chemical-only network to the connections collected from the hypermotif instances above.
nn_chem_sub = nn_chem.subnetwork(connections=all_edges)

In [None]:
# Rebuild the three-copy hypermotif and search for it in the chemical-synapse subnetwork only.
hm = utils.nx.relabel_nodes(
    utils.make_hypermotifs(motif, 3, [(3, 1)]),
    {'1.3-2.1': '2.1', '2.3-3.1': '3.1'},
)
all_ffgs = nn_chem_sub.search_motifs(hm)

In [None]:
# Load every JSON recording in the control-condition directory of the Atanas et al. (2023) data.
# The location can be overridden with the ATANAS_CONTROL_DIR environment variable instead of editing the path.
CONTROL_DATA_DIR = os.environ.get(
    'ATANAS_CONTROL_DIR',
    '/Users/sahilmoza/Documents/Postdoc/Yun Zhang/data/SteveFlavell-NeuroPAL-Cell/Control/',
)
jsons = {}
# os.listdir returns entries in arbitrary order; sort so the dataset order is reproducible.
for js in sorted(os.listdir(CONTROL_DATA_DIR)):
    if js.startswith('.'):
        continue  # hidden files such as macOS .DS_Store are not JSON and would crash json.load
    with open(os.path.join(CONTROL_DATA_DIR, js), 'r') as f:
        jsons['Atanas et al (2023) ' + js] = json.load(f)

In [None]:
# For each recording, map every confidently identified neuron label to the row of 'trace_array'
# that holds its trace (later cells index trace_array with these values).
# NOTE(review): if the keys of p['labeled'] are 1-based (e.g. exported from a Julia pipeline),
# set LABEL_INDEX_BASE = 1 — confirm against the dataset before trusting the traces.
LABEL_INDEX_BASE = 0
measuredNeurons = {}
neuron_labels = []
for js, p in jsons.items():
    sortedKeys = sorted(int(x) for x in p['labeled'].keys())
    # Labels containing '?' are uncertain identifications and are dropped.
    labelledNeurons = {p['labeled'][str(x)]['label']: x for x in sortedKeys if '?' not in p['labeled'][str(x)]['label']}
    # Keep each label's recorded index. Re-enumerating set(labelledNeurons) would assign arbitrary
    # (hash-ordered) positions that do not point at the label's own trace.
    measuredNeurons[js] = {label: idx - LABEL_INDEX_BASE for label, idx in labelledNeurons.items()}
    neuron_labels += measuredNeurons[js].keys()
neuron_labels = sorted(set(neuron_labels))

In [None]:
# Fit a rate model to the whole motif-containing chemical subnetwork, separately for each recording,
# and plot each fitted non-input neuron that has a recorded trace against its fit.
# NOTE(review): `input_neurons` is not defined in any visible cell — define it (names of the neurons
# used as inputs) before running. `num_trials` is set in the per-motif fitting cell further down.
from cedne import simulator  # imported here as well so this cell does not depend on a later cell
from cedne import optimizer

best_models = {}
for database in jsons.keys():
    ## Subnetwork and optimize
    nn_chem_sub = nn_chem.subnetwork(connections=all_edges)
    hm = utils.make_hypermotifs(motif, 3, [(3,1)])
    hm = utils.nx.relabel_nodes(hm, {'1.3-2.1':'2.1', '2.3-3.1':'3.1'})
    all_ffgs = nn_chem_sub.search_motifs(hm)
    ## Parameter Setup
    inputs = []
    tconstants = [1] * len(nn_chem_sub.nodes)
    input_nodes = [nn_chem_sub.neurons[n] for n in input_neurons]

    weights = {e: 1 for e in nn_chem_sub.edges}
    gains = {node: 1.0 for node in nn_chem_sub.nodes}
    baselines = {node: 0. for node in nn_chem_sub.nodes}
    time_constants = {n: t for n, t in zip(nn_chem_sub.nodes, tconstants)}
    # Length of the first measured neuron's trace; every attached trace is truncated to it.
    num_timepoints = len(jsons[database]['trace_array'][measuredNeurons[database][list(measuredNeurons[database].keys())[0]]])
    for neuron in nn_chem_sub.neurons:
        if neuron in measuredNeurons[database]:
            nn_chem_sub.neurons[neuron].set_property('amplitude', jsons[database]['trace_array'][measuredNeurons[database][neuron]][:num_timepoints])
    time_points = np.arange(num_timepoints)

    ## Inputs: replay each input neuron's recorded trace.
    for inp in input_nodes:
        if hasattr(inp, 'amplitude'):
            input_value = {t: inp.amplitude[t] for t in time_points}
            inputs.append(simulator.TimeSeriesInput([inp], input_value))

    ## Initialize rate model
    rate_model = simulator.RateModel(nn_chem_sub, input_nodes, weights, gains, time_constants, baselines, static_nodes=input_nodes,
                                     time_points=time_points, inputs=inputs)

    node_parameter_bounds = {'gain': {rn: (-1, 1) for n, rn in rate_model.node_dict.items() if n not in input_nodes},
                             'time_constant': {rn: (1, 5) for n, rn in rate_model.node_dict.items() if n not in input_nodes},
                             'baseline': {rn: (-2, 2) for n, rn in rate_model.node_dict.items() if n not in input_nodes}}
    edge_parameter_bounds = {'weight': {e: (-2, 2) for e in rate_model.edges}}

    real = {rate_model.node_dict[node]: data['amplitude'] for node, data in nn_chem_sub.nodes(data=True) if 'amplitude' in data}
    # Rate-model nodes of the inputs, computed once rather than once per candidate node.
    input_rate_nodes = [rate_model.node_dict[n] for n in input_nodes]
    vars_to_fit = [rn for rn in real.keys() if rn not in input_rate_nodes]

    ## Bound the parameters of interest, leave the rest at their defaults, and fit against the recorded traces.
    o = optimizer.OptunaOptimizer(rate_model, real, optimizer.mean_squared_error, node_parameter_bounds, edge_parameter_bounds, vars_to_fit, num_trials=num_trials)
    best_params, best_model = o.optimize()
    best_fit = best_model.simulate()

    best_models[database] = (best_params, best_model)

    plot_rows = [k for k in best_fit.keys() if str(k.label) not in input_neurons and hasattr(nn_chem_sub.neurons[str(k.label)], 'amplitude')]
    if not plot_rows:
        continue  # plt.subplots(nrows=0) raises, so skip recordings with nothing to plot
    # squeeze=False keeps `ax` 2-D even with a single row (a lone Axes cannot be indexed with ax[j]).
    f, ax = plt.subplots(figsize=(10, 2*len(plot_rows)), nrows=len(plot_rows), sharex=True, layout='constrained', squeeze=False)
    for j, k in enumerate(plot_rows):
        neuron = nn_chem_sub.neurons[str(k.label)]
        row_ax = ax[j, 0]
        row_ax.plot(neuron.amplitude, label=f'{k.label}-{neuron.name}', color='gray')
        row_ax.plot(best_fit[k], color='orange')
        utils.simpleaxis(row_ax)
        # Title is the Pearson correlation between the recorded trace and the fit.
        row_ax.set_title(f'{np.corrcoef(neuron.amplitude, best_fit[k])[0,1]}')
        row_ax.legend(frameon=False)
    f.suptitle(f'{database}')
    plt.show()
    plt.close(f)  # free the figure; one is created per recording

In [None]:
# Inspect the node parameter bounds left by the most recently executed fitting cell.
node_parameter_bounds

In [None]:
def update_params(best_params, rate_model, node_params, edge_params):
    """Return copies of ``node_params`` and ``edge_params`` scaled by optimiser results.

    Keys of ``best_params`` are ':'-separated strings:

    * ``"<param>:<node>"`` (2 fields) multiplies
      ``node_params[param][rate_model.node_dict[node]]`` by the value;
    * ``"<param>:<src>:<dst>:<x>"`` (4 fields) multiplies
      ``edge_params[param][(rate_model.node_dict[src], rate_model.node_dict[dst])]``;
      the 4th field is not used (presumably a multi-edge key — TODO confirm).

    Keys with any other number of fields are ignored. The inputs are deep-copied and left unchanged.

    NOTE(review): entries are scaled with ``*=``. This suits numeric parameter values; a
    ``(low, high)`` bounds tuple raises ``TypeError`` for a float factor and is repeated for an int.

    Returns:
        ``(node_params_new, edge_params_new)``.
    """
    node_params_new = copy.deepcopy(node_params)
    edge_params_new = copy.deepcopy(edge_params)
    for key, value in best_params.items():
        split_key = key.split(':')
        if len(split_key) == 2:
            param, node = split_key
            node_params_new[param][rate_model.node_dict[node]] *= value
        elif len(split_key) == 4:
            param, src, dst, _ = split_key
            # Both endpoints are translated to rate-model nodes; the edge key is that pair.
            edge = (rate_model.node_dict[src], rate_model.node_dict[dst])
            edge_params_new[param][edge] *= value
    return node_params_new, edge_params_new

In [None]:
# Scale the most recent parameter-bound dicts by the last optimiser result and display the copies.
# NOTE(review): the bound dicts hold (low, high) tuples and update_params multiplies entries with `*=`,
# so a float-valued best_params entry raises TypeError here (an int would repeat the tuple instead).
update_params(best_params=best_params, rate_model=rate_model, node_params=node_parameter_bounds, edge_params=edge_parameter_bounds)

In [None]:
from cedne import simulator
from cedne import optimizer
from cedne import GraphMap
# Configuration for the per-motif fits below (duplicate, immediately-overwritten assignments removed).
num_trials = 50                     # trials passed to each OptunaOptimizer run
best_models = {}                    # database -> (best_params, best_model)
input_nodes = ['1.1']               # hypermotif node(s) treated as inputs / static nodes of the rate model
min_motif = ['1.1', '1.2', '2.1']   # hypermotif nodes that must all have recorded traces before plotting


# For every recording and every hypermotif instance, fit a rate model on the hypermotif itself,
# fitting progressively larger sets of nodes. Plot the fit when the required nodes have traces
# and the first mapped neuron is sensory.
# NOTE(review): best_models[database] is overwritten for every motif instance, so only the last
# instance fitted per recording is kept.
for database in jsons.keys():
    nn_chem_sub = nn_chem.subnetwork(connections=all_edges)
    hm = utils.make_hypermotifs(motif, 3, [(3,1)])
    hm = utils.nx.relabel_nodes(hm, {'1.3-2.1':'2.1', '2.3-3.1':'3.1'})
    all_ffgs = nn_chem_sub.search_motifs(hm)
    num_timepoints = len(jsons[database]['trace_array'][measuredNeurons[database][list(measuredNeurons[database].keys())[0]]])
    for neuron in nn_chem_sub.neurons:
        if neuron in measuredNeurons[database]:
            nn_chem_sub.neurons[neuron].set_property('amplitude', jsons[database]['trace_array'][measuredNeurons[database][neuron]])
    time_points = np.arange(num_timepoints)

    for ffg in all_ffgs:
        # Annotates each hm node with the neuron mapped onto it (read back via data['map'] below).
        GraphMap(ffg, hm, nn_chem_sub, map_type='edge')

        ## Inputs: drive each input hypermotif node with the recorded trace of its mapped neuron.
        # Rebuilt per instance. Looping over the node-name strings never found an `amplitude`, so
        # `inputs` used to be a stale list left over from an earlier cell.
        inputs = []
        for inp in input_nodes:
            mapped_neuron = hm.nodes[inp]['map']
            if hasattr(mapped_neuron, 'amplitude'):
                input_value = {t: mapped_neuron.amplitude[t] for t in time_points}
                inputs.append(simulator.TimeSeriesInput([inp], input_value))

        ## Initialize rate model (all time constants start at 1; no dependence on another cell's `tconstants`)
        weights = {e: 1 for e in hm.edges}
        gains = {node: 1.0 for node in hm.nodes}
        baselines = {node: 0. for node in hm.nodes}
        time_constants = {node: 1 for node in hm.nodes}

        rate_model = simulator.RateModel(hm, input_nodes, weights, gains, time_constants, baselines, static_nodes=input_nodes,
                                         time_points=time_points, inputs=inputs)

        node_parameter_bounds = {'gain': {rn: (1, 1) for n, rn in rate_model.node_dict.items() if n not in input_nodes},
                                 'time_constant': {rn: (1, 5) for n, rn in rate_model.node_dict.items() if n not in input_nodes},
                                 'baseline': {rn: (0, 1) for n, rn in rate_model.node_dict.items() if n not in input_nodes}}
        edge_parameter_bounds = {'weight': {e: (-1, 1) for e in rate_model.edges}}

        # Recorded traces of the neurons mapped onto hypermotif nodes, keyed by rate-model node.
        real = {rate_model.node_dict[node]: data['map'].amplitude for node, data in hm.nodes(data=True) if hasattr(data['map'], 'amplitude')}

        ## Fit sorted node names [1:m] for m = 2..N, skipping the first sorted node (presumably the input '1.1').
        # With range(len(hm.nodes)), the first two fits had nothing to fit and the last node was never included.
        fit_order = sorted(rate_model.node_dict.keys())
        for m in range(2, len(hm.nodes) + 1):
            fit_nodes = [rate_model.node_dict[n] for n in fit_order[1:m]]
            vars_to_fit = [rn for rn in real.keys() if rn in fit_nodes]
            o = optimizer.OptunaOptimizer(rate_model, real, optimizer.mean_squared_error, node_parameter_bounds, edge_parameter_bounds, vars_to_fit, num_trials=num_trials)
            best_params, best_model = o.optimize()

        best_fit = best_model.simulate()

        best_models[database] = (best_params, best_model)

        # (hypermotif node, neuron name) pairs for edges whose two neurons both have recorded traces.
        nodelist = []
        for edge in sorted(edges):
            pre_name, post_name = ffg[edge][0].name, ffg[edge][1].name
            if hasattr(nn_chem_sub.neurons[pre_name], 'amplitude') and hasattr(nn_chem_sub.neurons[post_name], 'amplitude'):
                nodelist += [(edge[0], pre_name), (edge[1], post_name)]
        nodelist = sorted(set(nodelist))
        if len(nodelist) >= len(min_motif):
            mapped_hm_nodes = [hm_node for hm_node, _ in nodelist]
            if all(item in mapped_hm_nodes for item in min_motif):
                if nn_chem_sub.neurons[nodelist[0][1]].type == 'sensory':
                    # best_fit is keyed by rate-model nodes, not neuron names; look fits up by label,
                    # assumed to equal the hypermotif node name — TODO confirm.
                    fit_by_label = {str(k.label): trace for k, trace in best_fit.items()}
                    f, ax = plt.subplots(figsize=(10, 2*len(hm.nodes)), nrows=len(hm.nodes), sharex=True, sharey=True)
                    for k, (hm_node, node) in enumerate(nodelist):
                        # Plot recording and fit over the same window so they stay time-aligned.
                        ax[k].plot(nn_chem_sub.neurons[node].amplitude[1000:2500], label=f'{hm_node}: {node}', color='gray')
                        if hm_node in fit_by_label:
                            ax[k].plot(fit_by_label[hm_node][1000:2500], color='orange')
                        ax[k].legend(frameon=False)
                    utils.simpleaxis(ax)
                    plt.show()
                    plt.close(f)
                else:
                    print(nn_chem_sub.neurons[nodelist[0][1]].type, nn_chem_sub.neurons[nodelist[-1][1]].type)

In [None]:
# Show which neuron each hypermotif node was mapped to by the last GraphMap call in the fitting loop.
for hm_node, node_data in hm.nodes(data=True):
    print(hm_node, node_data['map'])

In [None]:
# Scratch check: element [0] (the source-side neuron, as used for edge[0] in the fitting loop) of the
# connection mapped onto hypermotif edge ('1.1', '1.2') in the last motif instance from that loop.
ffg[('1.1', '1.2')][0]

In [None]:
# Scratch check: number of (hypermotif node, neuron) pairs with recorded traces in the last motif instance.
len(nodelist)