### Sequential Hierarchy simulation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from cedne import utils
nx = utils.nx
import cedne.cedne as cdn
import itertools
from scipy.integrate import odeint
from joblib import Parallel, delayed
import matplotlib

In [None]:
def allocate_parameters(input_graph, num_graphs, weight_range, time_constant_range, gain=1, log_base=10):
    """
    Generate randomly parameterised copies of a graph.

    Each copy gets an independent uniform random weight on every edge and an
    independent log-uniform random time constant on every node. Every node also
    receives the same ``gain``. The input graph itself is not modified.

    Parameters:
    input_graph (networkx.Graph): Template graph; it is copied, never mutated.
    num_graphs (int): Number of new graphs to generate (0 or less gives ``[]``).
    weight_range (tuple): ``(low, high)`` bounds for the uniform edge weights.
    time_constant_range (tuple): ``(low, high)`` bounds for the *exponent* of the
        time constant. Each node gets ``log_base ** u`` with ``u ~ U(low, high)``.
        With the default ``log_base=10``, ``(0, 1)`` yields time constants
        between 1 and 10.
    gain (float): Gain assigned to every node (default 1).
    log_base (float): Base of the log-uniform time-constant draw (default 10).

    Returns:
    graphs (list): ``num_graphs`` graph copies carrying the edge attribute
        ``'weight'`` and the node attributes ``'time_constant'`` and ``'gain'``.
    """
    weight_low, weight_high = weight_range[0], weight_range[1]
    exponent_low, exponent_high = time_constant_range[0], time_constant_range[1]

    graphs = []
    for _ in range(num_graphs):
        new_graph = input_graph.copy()

        # Draw order (all edges first, then one draw per node) is kept so that
        # seeded runs reproduce the same parameters as before.
        for u, v in new_graph.edges():
            new_graph[u][v]['weight'] = np.random.uniform(weight_low, weight_high)

        for node in new_graph.nodes():
            exponent = np.random.uniform(exponent_low, exponent_high)
            new_graph.nodes[node]['time_constant'] = log_base ** exponent
            new_graph.nodes[node]['gain'] = gain

        graphs.append(new_graph)

    return graphs

In [None]:
class Input:
    """Exponentially decaying external drive, active only strictly inside (tstart, tend).

    ``decay_rate`` divides the elapsed time, so it acts as a decay time constant:
    a larger value means slower decay.
    """

    def __init__(self, input_nodes, tstart, tend, value, decay_rate):
        # Nodes this input is intended for (stored for callers; not used here).
        self.input_nodes = input_nodes
        self.tstart = tstart
        self.tend = tend
        self.value = value
        self.decay_rate = decay_rate

    def process_input(self, t):
        """Return the drive at time ``t``: ``value * exp(-(t - tstart) / decay_rate)`` inside the window, else 0."""
        if not (self.tstart < t < self.tend):
            return 0
        elapsed = t - self.tstart
        return self.value * np.exp(-elapsed / self.decay_rate)
        
class RateModel:
    """Linear rate model on a directed graph, integrated with forward Euler."""

    def __init__(self, graph, input_nodes, weights=None, gains=None, time_constants=None):
        """
        Initialize the rate model.

        Args:
            graph (networkx.DiGraph): The graph representing the neural network.
            input_nodes (list): Nodes that receive external input.
            weights (dict, optional): Maps each edge ``(pre, post)`` to its weight.
                Defaults to the ``'weight'`` attribute of every graph edge.
            gains (list, optional): One gain per node, in sorted node order.
                Defaults to the ``'gain'`` attribute of every node.
            time_constants (list, optional): One time constant per node, in sorted
                node order. Defaults to the ``'time_constant'`` attribute of every node.
        """
        self.graph = graph
        if weights is None:
            weights = {edge: graph[edge[0]][edge[1]]['weight'] for edge in graph.edges}
        self.weights = weights
        self.input_nodes = input_nodes
        # Sorted node order defines the position of every node in gains,
        # time_constants, rates and derivatives.
        self.nodes = sorted(graph.nodes)
        if gains is None:
            gains = [graph.nodes[node]['gain'] for node in self.nodes]
        self.gains = gains

        assert time_constants is not None or all(
            'time_constant' in graph.nodes[node] for node in self.nodes
        ), "Each node must have a time constant"
        if time_constants is None:
            time_constants = [graph.nodes[node]['time_constant'] for node in self.nodes]
        self.time_constants = time_constants

    def rate_equations(self, rates, t, inputs):
        """
        Compute d(rate)/dt for every node.

        For node i this is ``-rates[i] / tau_i``, plus ``gain_i * drive`` for each
        input if i is an input node, plus ``gain_i * rates[j] * w(j -> i)`` for
        each predecessor j.

        Args:
            rates (sequence): Current rates, in ``self.nodes`` order.
            t (float): Current time.
            inputs (list): Objects with a ``process_input(t)`` method. Every input
                drives every node in ``self.input_nodes``.
                NOTE(review): each input's own ``input_nodes`` attribute is ignored.

        Returns:
            numpy.ndarray: The derivatives, in ``self.nodes`` order.
        """
        # Node -> position map gives O(1) predecessor lookups (list.index was O(n) per edge).
        index = {node: i for i, node in enumerate(self.nodes)}
        input_node_set = set(self.input_nodes)
        # Input drives depend only on t, so evaluate each input once per call.
        drives = [inp.process_input(t) for inp in inputs]

        derivatives = np.zeros(len(self.nodes))
        for i, node in enumerate(self.nodes):
            gain = self.gains[i]
            derivative = -rates[i] / self.time_constants[i]
            if node in input_node_set:
                for drive in drives:
                    derivative += gain * drive
            for predecessor in self.graph.predecessors(node):
                derivative += gain * rates[index[predecessor]] * self.weights[(predecessor, node)]
            derivatives[i] = derivative
        return derivatives

    def simulate(self, time_points, inputs):
        """
        Integrate the rates over ``time_points`` with forward Euler, starting from zero.

        Args:
            time_points (sequence of float): Time grid; steps may be non-uniform.
            inputs (list of Input): External inputs applied to ``self.input_nodes``.

        Returns:
            numpy.ndarray: Shape ``(len(time_points), len(self.nodes))``. Row k holds
            the rates at ``time_points[k]``; row 0 is the all-zero initial state.
            An empty ``time_points`` gives an empty ``(0, n_nodes)`` array.
        """
        assert all(isinstance(inp, Input) for inp in inputs)
        n_steps = len(time_points)
        simulated_rates = np.zeros((n_steps, len(self.nodes)))
        # Row 0 is already the zero initial condition.
        for k in range(1, n_steps):
            dt = time_points[k] - time_points[k - 1]
            derivatives = self.rate_equations(simulated_rates[k - 1], time_points[k - 1], inputs)
            simulated_rates[k] = simulated_rates[k - 1] + derivatives * dt
        return simulated_rates

## Feedforward loop motif

In [None]:
# Three-node feedforward loop: every edge weight is negative (inhibitory).
G = nx.DiGraph()
weights = {(0, 1): -3., (1, 2): -1, (0, 2): -3}
G.add_edges_from(weights.keys())

input_nodes = [0]
# gains and time_constants are listed in sorted node order (0, 1, 2).
gains = [1.0, 1.0, 1.0]
time_constants = [10, 1, 10]
rate_model = RateModel(G, input_nodes, weights, gains, time_constants)

time_points = np.linspace(0, 90, 451)  # 0.2 s resolution

## First input (start/end are indices into time_points)
inp1_start = 150
inp1_end = 300
inp1_value = 1
decay_rate_1 = 20  ## Decay time constant of the input (receptor adaptation)

## Second input
inp2_start = 300
inp2_end = 450
inp2_value = 0.2
decay_rate_2 = 20  ## Decay time constant of the input (receptor adaptation)

input1 = Input(input_nodes, tstart=time_points[inp1_start], tend=time_points[inp1_end], value=inp1_value, decay_rate=decay_rate_1)
input2 = Input(input_nodes, tstart=time_points[inp2_start], tend=time_points[inp2_end], value=inp2_value, decay_rate=decay_rate_2)

inputs = [input1, input2]

simulated_rates = rate_model.simulate(time_points, inputs)

f, ax = plt.subplots(figsize=(2.5, 2.5), layout='constrained')
# One line per node; columns of simulated_rates follow rate_model.nodes.
ax.plot(time_points, simulated_rates, label=[str(node) for node in rate_model.nodes], lw=2)
ax.axhline(y=0, ls='--', alpha=0.25, color='gray')
for inp in inputs:
    ax.axvline(x=inp.tstart, ls='--', color='gray', alpha=0.25)
    ax.axvline(x=inp.tend, ls='--', color='gray', alpha=0.25)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Rate')
ax.set_xticks([0, 30, 60, 90])
utils.simpleaxis(ax)
# One legend column per node (shape[1]), not per time step (len()).
f.legend(loc='outside upper center', ncol=simulated_rates.shape[1], frameon=False)
plt.show()

In [None]:
# Start from the 030T triad and swap the labels of nodes 2 and 3.
triad_motifs = utils.return_triads()
relabel_map = {1: 1, 2: 3, 3: 2}
motif = nx.relabel_nodes(triad_motifs['030T'], relabel_map)

# Build a 3-motif hypermotif linked via node pair (3, 1), then give the merged
# junction nodes (named '<a>.3-<b>.1' by make_hypermotifs, presumably) short names.
hm = utils.make_hypermotifs(motif, 3, [(3, 1)])
junction_names = {'1.3-2.1': '2.1', '2.3-3.1': '3.1'}
hm = nx.relabel_nodes(hm, junction_names)

In [None]:
input_nodes = ['1.1']
gains = [1.0] * len(hm.nodes)
maxT = 25
tsteps = 1000
time_points = np.linspace(0, maxT, tsteps)
num_iterations = 20

## First input (start/end are indices into time_points)
inp1_start = 10
inp1_end = 20
inp1_value = 1
decay_rate_1 = 20  ## Decay time constant of the input (receptor adaptation)

input1 = Input(input_nodes, tstart=time_points[inp1_start], tend=time_points[inp1_end], value=inp1_value, decay_rate=decay_rate_1)

inputs = [input1]


simulated_out = {node: [] for node in hm.nodes}
for iteration in range(num_iterations):
    # Scalar draws keep weights and time constants as plain floats; size-1
    # arrays would turn every derivative into an array.
    weights = {e: np.random.uniform(-1, 1) for e in hm.edges}
    time_constants = [2**np.random.uniform(-1, 1) for _ in hm.nodes]
    rate_model = RateModel(hm, input_nodes, weights, gains, time_constants)
    simulated_rates = rate_model.simulate(time_points, inputs)

    # Columns of simulated_rates follow rate_model.nodes (sorted order); node
    # names are strings, so they must be mapped to column positions.
    for col, node in enumerate(rate_model.nodes):
        simulated_out[node].append(simulated_rates[:, col])

# Shared symmetric colour limits so one colour bar describes every panel.
vm = max(np.max(np.abs(simulated_out[node])) for node in hm.nodes)
model_ids = np.arange(num_iterations)
f, ax = plt.subplots(figsize=(14, 2), ncols=len(hm.nodes), layout='constrained', sharex=True, sharey=True)
for k, node in enumerate(sorted(hm.nodes)):
    # x = time_points puts the heat map on the same seconds-based axis as the
    # input markers and ticks; pcolormesh returns a single artist.
    n1 = ax[k].pcolormesh(time_points, model_ids, np.array(simulated_out[node]),
                          shading='nearest', vmin=-vm, vmax=vm, cmap='PuOr')
    for inp in inputs:
        ax[k].axvline(x=inp.tstart, ls='--', color='gray', alpha=0.05)
        ax[k].axvline(x=inp.tend, ls='--', color='gray', alpha=0.2)
    ax[k].set_title(node)
    utils.simpleaxis(ax[k])

f.colorbar(n1, ax=ax[-1], label='Rate')
f.supxlabel('Time (s)')
f.supylabel('Model #')
ax[0].set_xticks(np.linspace(0, maxT, 6))
plt.show()

## Using simulator module

In [None]:
from cedne import simulator

# Single 030T motif with nodes 2 and 3 swapped, as in the hypermotif cell.
triads = utils.return_triads()
motif = nx.relabel_nodes(triads['030T'], {1: 1, 2: 3, 3: 2})
chain_length = 1
hm = utils.make_hypermotifs(motif, chain_length, [(3, 1)])

# All three edges carry weight -1.
weights = {('1.1', '1.2'): -1, ('1.1', '1.3'): -1., ('1.2', '1.3'): -1}
input_nodes = ['1.1']
gains = dict.fromkeys(hm.nodes, 0.1)
baselines = dict.fromkeys(hm.nodes, 0.)
time_constants = dict.fromkeys(hm.nodes, 1)


maxT = 15
tsteps = 500
initial_rates = [0.] * len(hm.nodes)
time_points = np.linspace(0, maxT, tsteps)

# First input: start/end are indices into time_points.
inp1_start = 100
inp1_end = 200
inp1_value = 1

input1 = simulator.StepInput(
    input_nodes,
    tstart=time_points[inp1_start],
    tend=time_points[inp1_end],
    value=inp1_value,
)
inputs = [input1]

rate_model = simulator.RateModel(
    hm, input_nodes, weights, gains, time_constants, baselines,
    time_points=time_points, inputs=inputs,
)
simulated_rates = rate_model.simulate()
f = utils.plot_simulation_results((rate_model, inputs, simulated_rates))
f.savefig('simulated_rates_FFL.svg')

In [None]:
# Three-motif hypermotif with short names for the merged junction nodes.
chain_length = 3
hm = utils.make_hypermotifs(motif, chain_length, [(3, 1)])
hm = nx.relabel_nodes(hm, {'1.3-2.1': '2.1', '2.3-3.1': '3.1'})

input_nodes = ['1.1']
gains = dict.fromkeys(hm.nodes, 0.1)
baselines = dict.fromkeys(hm.nodes, 0.)
time_constants = dict.fromkeys(hm.nodes, 1)
num_iterations = 1000

# time_points and inputs are reused from the previous cell.
simulated_out = {node: [] for node in hm.nodes}
for j in range(num_iterations):
    # Each weight is a size-1 array drawn from U(-1, 1).
    weights = {edge: np.random.uniform(-1, 1, 1) for edge in hm.edges}
    rate_model = simulator.RateModel(
        hm, input_nodes, weights, gains, time_constants, baselines,
        time_points=time_points, inputs=inputs,
    )
    simulated_rates = rate_model.simulate()
    # Result keys are node objects; store each time course under its label.
    for node, tcourse in simulated_rates.items():
        simulated_out[node.label].append(tcourse)

In [None]:
# Scratch check: 10 raised to 0.1.
pow(10, 0.1)

In [None]:
# Heat maps of every sampled model's response (one column per node) with the
# input step drawn above the first column. Heat maps are drawn in time-step
# (column) units, so every time marker is converted to that unit.
f, (ax_inp, ax) = plt.subplots(figsize=(7, 2), ncols=len(hm.nodes), nrows=2, layout='constrained',
                               sharex=True, sharey='row', gridspec_kw={'height_ratios': [1, 10]})
sortby_node = '2.1'
# Order models by the summed response of sortby_node (stable, like sorted()).
argsorted_indices = np.argsort(np.sum(simulated_out[sortby_node], axis=1), kind='stable')

# linspace(0, maxT, tsteps) has tsteps - 1 intervals, so this maps seconds to column index.
steps_per_second = (tsteps - 1) / maxT
# Step input in column units, held at 0 after inp1_end until the last column.
input_times = [0, inp1_start, inp1_end, tsteps]
input_values = [0, inp1_value, 0, 0]
tick_positions = np.array([0, maxT / 2, maxT]) * steps_per_second
vm = 0.01  # fixed colour limit so all panels share one scale

for k, node in enumerate(sorted(hm.nodes)):
    n1 = ax[k].pcolormesh(np.array(simulated_out[node])[argsorted_indices], vmin=-vm, vmax=vm, cmap='PuOr', rasterized=True)
    utils.simpleaxis(ax[k])
    ax[k].set_xticks(tick_positions, [0, maxT / 2, maxT])
    for inp in inputs:
        ax[k].axvline(x=inp.tstart * steps_per_second, ls='--', color='r', alpha=0.2)
        ax[k].axvline(x=inp.tend * steps_per_second, ls='--', color='r', alpha=0.2)

    if k == 0:
        ax_inp[k].step(input_times, input_values, where='post', color='red', linewidth=2)
    utils.simpleaxis(ax_inp[k])
    ax_inp[k].set_yticks([])  # Remove y-ticks
    ax_inp[k].set_ylabel("")  # Remove y-label
    ax_inp[k].set_ylim((0, max(1, inp1_value)))  # Keep the whole step visible
    ax_inp[k].spines["left"].set_visible(False)  # Hide the left spine
    ax_inp[k].set_title(node)

f.colorbar(n1, ax=ax[-1], label='Rate')
f.supxlabel('Time (s)')
f.supylabel('Model #')
f.savefig('sim_ffl.svg', format="svg", transparent=True)
plt.show()
plt.close()

In [None]:
# Cumulative distribution of every rate sample (all models, all time steps), one panel per node.
f, ax = plt.subplots(figsize=(14, 2), ncols=len(hm.nodes), layout='constrained', sharex=True, sharey=True)
for j, node in enumerate(sorted(simulated_out)):
    panel = ax[j]
    all_samples = np.ravel(simulated_out[node])
    panel.hist(all_samples, density=True, cumulative=True, histtype='step', color='gray')
    panel.set_xlim((-0.3, 0.3))
    panel.set_xticks((-0.3, 0, 0.3))
    utils.simpleaxis(panel)
    panel.set_title(node)
plt.show()

In [None]:
# Draw hm with edges coloured by the most recent `weights`. The weights may be
# size-1 arrays, which networkx does not treat as numeric colour values;
# .item() turns each one (scalar or size-1 array) into a plain float.
edge_weights = [np.asarray(weights[e]).item() for e in hm.edges]
utils.nx.draw(hm, with_labels=True, edge_color=edge_weights, edge_cmap=plt.cm.coolwarm, edge_vmin=-1, edge_vmax=1)

In [None]:
# Planar drawing of hm with every edge labelled by its most recent weight.
node_color = 'white'
edge_colors = ['gray'] * len(hm.edges)
pos = nx.planar_layout(hm)
f, ax = plt.subplots(figsize=(3, 3), layout='constrained')
nx.draw_networkx(
    hm,
    ax=ax,
    with_labels=True,
    node_color=node_color,
    font_size='xx-large',
    edge_color=edge_colors,
    node_size=400,
    arrowsize=20,
    width=1,
    pos=pos,
)

# Weights may be size-1 arrays; formatting an ndarray with '.2f' raises
# TypeError, so convert each weight to a plain float first.
nx.draw_networkx_edge_labels(
    hm, pos,
    edge_labels={edge: "{:.2f}".format(np.asarray(weights[edge]).item()) for edge in hm.edges},
    font_color='k', font_size='xx-large', ax=ax,
)
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
f.savefig('ffl-network-params.svg', transparent=True)

In [None]:
# Most recent weight of every edge as a list; np.ravel(...)[0] works whether a
# weight is a scalar or a size-1 array.
[np.ravel(weights[edge])[0] for edge in hm.edges]