diff --git a/pyrates/backend/computegraph.py b/pyrates/backend/computegraph.py index abc6eafa..0935f995 100755 --- a/pyrates/backend/computegraph.py +++ b/pyrates/backend/computegraph.py @@ -203,6 +203,8 @@ def __init__(self, # edge operators equations, variables_tmp = self._collect_op_layers(layers=[0], exclude=False, op_identifier="edge_from_") variables.update(variables_tmp) + if equations: + self.backend._input_layer_added = True # node operators equations_tmp, variables_tmp = self._collect_op_layers(layers=[], exclude=True, op_identifier="edge_from_") @@ -391,21 +393,7 @@ def run(self, inp_dict[var.name] = np.reshape(val[:, i], (sim_steps,) + tuple(var.shape)) i += 1 - # add inputs to graph - self.backend.add_layer(to_beginning=True) - - # create counting index for input variables - in_idx = self.backend.add_var(vtype='state_var', name='in_var_idx', dtype='int32', shape=(1,), value=0, - scope="network_inputs") - - for key, var in inp_dict.items(): - var_name = f"{var.short_name}_inp" if hasattr(var, 'short_name') else "var_inp" - in_var = self.backend.add_var(vtype='state_var', name=var_name, scope="network_inputs", value=var) - in_var_idx = self.backend.add_op('index', in_var, in_idx, scope="network_inputs") - self.backend.add_op('=', self.backend.vars[key], in_var_idx, scope="network_inputs") - - # create increment operator for counting index - self.backend.add_op('+=', in_idx, np.ones((1,), dtype='int32'), scope="network_inputs") + self.backend.add_input_layer(inputs=inp_dict) # run simulation ################ diff --git a/pyrates/backend/numpy_backend.py b/pyrates/backend/numpy_backend.py index 0589db51..4f915c13 100644 --- a/pyrates/backend/numpy_backend.py +++ b/pyrates/backend/numpy_backend.py @@ -858,6 +858,7 @@ def __init__(self, self._rt_optimization = jit_compile self._base_layer = 0 self._output_layer_added = False + self._input_layer_added = False self._imports = ["import numpy as np", "from pyrates.backend.funcs import *"] if imports: for imp in imports: @@ -1146,23 +1147,48 @@ def add_layer(self, to_beginning=False) -> None: self.layer = len(self.layers) - self._base_layer self.layers.append([]) - def add_output_layer(self, outputs, sampling_steps, out_shapes): + def add_output_layer(self, outputs, sampling_steps, out_shapes) -> dict: + """ - output_col = {} + Parameters + ---------- + outputs + sampling_steps + out_shapes + + Returns + ------- + + """ - # create counting index for collector variables - out_idx = self.add_var(vtype='state_var', name='out_var_idx', dtype='int32', shape=(1,), value=0, - scope="output_collection") + output_col = {} # add output storage layer to the graph if self._output_layer_added: + + out_idx = self.get_var("output_collection/out_var_idx") + + # jump to output layer self.top_layer() + else: + + # add output layer self.add_layer() self._output_layer_added = True + # create counting index for collector variables + out_idx = self.add_var(vtype='state_var', name='out_var_idx', dtype='int32', shape=(1,), value=0, + scope="output_collection") + + # create increment operator for counting index + self.next_layer() + self.add_op('+=', out_idx, np.ones((1,), dtype='int32'), scope="output_collection") + self.previous_layer() + # add collector variables to the graph for i, (var_col) in enumerate(outputs): + shape = (sampling_steps + 1, len(var_col)) + out_shapes[i] key = f"output_col_{i}" output_col[key] = self.add_var(vtype='state_var', name=f"out_col_{i}", scope="output_collection", @@ -1172,11 +1198,39 @@ def add_output_layer(self, outputs, sampling_steps, out_shapes): # add collect operation to the graph self.add_op('=', output_col[key], var_stack, out_idx, scope="output_collection") - # create increment operator for counting index - self.add_op('+=', out_idx, np.ones((1,), dtype='int32'), scope="output_collection") - return output_col + def add_input_layer(self, inputs: dict) -> None: + """ + + Parameters + ---------- + inputs + + Returns + ------- + + """ + + # add inputs to graph + if self._input_layer_added: + self.bottom_layer() + else: + self.add_layer(to_beginning=True) + + # create counting index for input variables + in_idx = self.add_var(vtype='state_var', name='in_var_idx', dtype='int32', shape=(1,), value=0, + scope="network_inputs") + + for key, var in inputs.items(): + var_name = f"{var.short_name}_inp" if hasattr(var, 'short_name') else "var_inp" + in_var = self.add_var(vtype='state_var', name=var_name, scope="network_inputs", value=var) + in_var_idx = self.add_op('index', in_var, in_idx, scope="network_inputs") + self.add_op('=', self.vars[key], in_var_idx, scope="network_inputs") + + # create increment operator for counting index + self.add_op('+=', in_idx, np.ones((1,), dtype='int32'), scope="network_inputs") + def next_layer(self) -> None: """Jump to next layer in stack. If we are already at end of layer stack, add new layer to the stack and jump to that. @@ -1251,16 +1305,24 @@ def clear(self) -> None: self.layer = 0 rmtree(self._build_dir) - def get_var(self, name, updated=True) -> NumpyVar: + def get_layer(self, idx) -> list: + """Retrieve layer from graph. + + Parameters + ---------- + idx + Position of layer in stack. + + """ + return self.layers[self._base_layer + idx] + + def get_var(self, name): """Retrieve variable from graph. Parameters ---------- name Identifier of the variable. - updated - If true, return updated value of a state-variable. Else, return the old, non-updated variable of a - state-variable. Returns ------- @@ -1268,24 +1330,7 @@ def get_var(self, name, updated=True) -> NumpyVar: Variable from graph. """ - if updated: - return self.vars[name] - else: - try: - return self.vars[f'{name}_old'] - except KeyError: - return self.vars[name] - - def get_layer(self, idx) -> list: - """Retrieve layer from graph. - - Parameters - ---------- - idx - Position of layer in stack. - - """ - return self.layers[self._base_layer + idx] + return self.vars[name] def eval_var(self, var) -> np.ndarray: """Get value of variable. diff --git a/pyrates/backend/parser.py b/pyrates/backend/parser.py index 6971cd17..874b8a71 100755 --- a/pyrates/backend/parser.py +++ b/pyrates/backend/parser.py @@ -233,7 +233,7 @@ def parse_expr(self): if self.lhs_key in self.vars and not self._instantaneous: # create new variable for lhs update - name = f"{self.lhs_key}_update" + name = f"{self.lhs_key}_delta" i = 0 name = f"{name}_{i}" while name in self.vars: @@ -728,7 +728,8 @@ def parse_equation_system(equations: list, equation_args: dict, backend: tp.Any, update_num += 1 state_vars = {key: var for key, var in equation_args.items() if (type(var) is dict) and ('vtype' in var) and (var['vtype'] == 'state_var')} - equations_tmp, state_vars = update_lhs(deepcopy(equations), state_vars, update_num, {}) + state_vars_orig = dict(state_vars.items()) + equations_tmp, state_vars = update_lhs(deepcopy(equations), state_vars, update_num, state_vars_orig) equation_args.update(state_vars) updates, equation_args = parse_equations(equations=equations_tmp, equation_args=equation_args, backend=backend, **kwargs) @@ -751,7 +752,7 @@ def parse_equation_system(equations: list, equation_args: dict, backend: tp.Any, state_vars = {key: var for key, var in equation_args.items() if (type(var) is dict) and ('vtype' in var) and (var['vtype'] == 'state_var')} state_vars_orig = dict(state_vars.items()) - equations_tmp, state_vars = update_lhs(deepcopy(equations), state_vars, update_num, {}) + equations_tmp, state_vars = update_lhs(deepcopy(equations), state_vars, update_num, state_vars_orig) for var_orig in state_vars_orig.copy().keys(): if not any([var_orig in var for var in state_vars]): state_vars_orig.pop(var_orig) @@ -1029,7 +1030,7 @@ def update_lhs(equations: list, equation_args: dict, update_num: int, var_dict: equations[i][j] = (f"{lhs} = dt * ({rhs})", scope) if add_to_args: for var_key, var in var_dict.copy().items(): - if var_key in key: + if var_key == key or f"{var_key}_upd_" in key: arg = var break updated_args[f"{node}/{op}/{new_var}"] = arg @@ -1043,21 +1044,29 @@ def update_equation_args(args: dict, updates: dict) -> dict: Parameters ---------- args - udpates + updates Returns ------- """ - args.update(updates) + args_new = {} + + for key, arg in args.items(): + if key in updates: + args_new[key] = updates[key] + else: + args_new[key] = arg + inputs = [key for key in args if 'inputs' in key] for inp in inputs: for in_key, in_map in args[inp].copy().items(): for upd in updates: if in_map in upd: - args[inp].update({upd.split('/')[-1]: upd}) + args_new[inp].update({upd.split('/')[-1]: upd}) break - return args + + return args_new def parse_dict(var_dict: dict, backend, **kwargs) -> dict: diff --git a/pyrates/backend/tensorflow_backend.py b/pyrates/backend/tensorflow_backend.py index f8705e80..544b852f 100644 --- a/pyrates/backend/tensorflow_backend.py +++ b/pyrates/backend/tensorflow_backend.py @@ -235,6 +235,29 @@ def broadcast(self, op1: Any, op2: Any, **kwargs) -> tuple: return super().broadcast(op1, op2, **kwargs) + def get_var(self, name): + """Retrieve variable from graph. + + Parameters + ---------- + name + Identifier of the variable. + + Returns + ------- + NumpyVar + Variable from graph. + + """ + try: + return self.vars[name] + except KeyError as e: + for var in self.vars: + if f"{name}:" in var: + return self.vars[var] + else: + raise e + def stack_vars(self, vars, **kwargs): var_count = {} for var in vars: diff --git a/tests/test_compute_graph.py b/tests/test_compute_graph.py index 5cf15fc0..520a7ca0 100755 --- a/tests/test_compute_graph.py +++ b/tests/test_compute_graph.py @@ -105,7 +105,6 @@ def test_2_1_operator(): # simulate operator behavior results = net.run(sim_time, inputs={'pop0.0/op1.0/u': inp}, outputs={'a': 'pop0.0/op1.0/a'}) - net.clear() # calculate operator behavior from hand update1 = lambda x, y: x + dt*(y-x) @@ -113,8 +112,9 @@ def test_2_1_operator(): for i in range(sim_steps): targets[i+1] = update1(targets[i], inp[i]) - diff = results['a'].values - targets[:-1] + diff = results['a'].values[1:] - targets[:-2] assert np.mean(np.abs(diff)) == pytest.approx(0., rel=1e-6, abs=1e-6) + net.clear() # test correct numerical evaluation of operator with two coupled equations (1 ODE, 1 linear eq.) ################################################################################################ @@ -122,7 +122,6 @@ def test_2_1_operator(): net_config = CircuitTemplate.from_yaml("model_templates.test_resources.test_compute_graph.net2").apply() net = ComputeGraph(net_config=net_config, name='net2', vectorization='none', dt=dt, backend=b) results = net.run(sim_time, outputs={'a': 'pop0.0/op2.0/a'}) - #net.clear() # calculate operator behavior from hand update2 = lambda x: 1./(1. + np.exp(-x)) @@ -133,6 +132,7 @@ def test_2_1_operator(): diff = results['a'].values[:, 0] - targets[:-1, 0] assert np.mean(np.abs(diff)) == pytest.approx(0., rel=1e-6, abs=1e-6) + net.clear() # test correct numerical evaluation of operator with a two coupled DEs ###################################################################### @@ -143,7 +143,6 @@ def test_2_1_operator(): outputs={'b': 'pop0.0/op3.0/b'}, inputs={'pop0.0/op3.0/u': inp}, out_dir="/tmp/log") - net.clear() # calculate operator behavior from hand update3_0 = lambda a, b, u: a + dt*(-10.*a + b**2 + u) @@ -153,8 +152,9 @@ def test_2_1_operator(): targets[i+1, 0] = update3_0(targets[i, 0], targets[i, 1], inp[i]) targets[i+1, 1] = update3_1(targets[i, 1], targets[i, 0]) - diff = results['b'].values[:, 0] - targets[:-1, 1] + diff = results['b'].values[1:, 0] - targets[:-2, 1] assert np.mean(np.abs(diff)) == pytest.approx(0., rel=1e-6, abs=1e-6) + net.clear() def test_2_2_node(): @@ -250,8 +250,8 @@ def test_2_2_node(): targets[i+1, 2] = update3(targets[i, 2], targets[i, 3], targets[i+1, 0]) targets[i+1, 3] = update4(targets[i, 3], targets[i, 2]) - diff = np.mean(np.abs(results['a'].values[:, 0] - targets[:-1, 1])) + \ - np.mean(np.abs(results['b'].values[:, 0] - targets[:-1, 3])) + diff = np.mean(np.abs(results['a'].values[1:, 0] - targets[:-2, 1])) + \ + np.mean(np.abs(results['b'].values[1:, 0] - targets[:-2, 3])) assert diff == pytest.approx(0., rel=1e-6, abs=1e-6) @@ -292,8 +292,8 @@ def test_2_3_edge(): # simulate edge behavior results = net.run(sim_time, outputs={'a': 'pop1.0/op1.0/a', 'b': 'pop2.0/op1.0/a'}) - diff = np.mean(np.abs(results['a']['pop1.0/op1.0'].values - targets[1:, 2])) + \ - np.mean(np.abs(results['b']['pop2.0/op1.0'].values - targets[1:, 3])) + diff = np.mean(np.abs(results['a']['pop1.0/op1.0'].values - targets[:-1, 2])) + \ + np.mean(np.abs(results['b']['pop2.0/op1.0'].values - targets[:-1, 3])) assert diff == pytest.approx(0., rel=1e-6, abs=1e-6) # test correct numerical evaluation of graph with 2 bidirectionaly coupled nodes @@ -315,8 +315,8 @@ def test_2_3_edge(): targets[i + 1, 0] = update2(targets[i, 0], targets[i, 1] * 0.5) targets[i + 1, 1] = update3(targets[i, 1], targets[i, 0] * 2.0, inp[i]) - diff = np.mean(np.abs(results['a']['pop0.0/op1.0'].values - targets[:-1, 0])) + \ - np.mean(np.abs(results['b']['pop1.0/op7.0'].values - targets[:-1, 1])) + diff = np.mean(np.abs(results['a']['pop0.0/op1.0'].values[1:] - targets[:-2, 0])) + \ + np.mean(np.abs(results['b']['pop1.0/op7.0'].values[1:] - targets[:-2, 1])) assert diff == pytest.approx(0., rel=1e-6, abs=1e-6) # test correct numerical evaluation of graph with 2 bidirectionally delay-coupled nodes @@ -338,8 +338,8 @@ def test_2_3_edge(): targets[i + 1, 0] = update4(inp0 * 0.5) targets[i + 1, 1] = update4(inp1 * 2.0) - diff = np.mean(np.abs(results['a']['pop0.0/op8.0'].values - targets[:-1, 0])) + \ - np.mean(np.abs(results['b']['pop1.0/op8.0'].values - targets[:-1, 1])) + diff = np.mean(np.abs(results['a']['pop0.0/op8.0'].values[1:] - targets[:-2, 0])) + \ + np.mean(np.abs(results['b']['pop1.0/op8.0'].values[1:] - targets[:-2, 1])) assert diff == pytest.approx(0., rel=1e-6, abs=1e-6) # test correct numerical evaluation of graph with 2 unidirectionally, multi-delay-coupled nodes @@ -361,8 +361,8 @@ def test_2_3_edge(): targets[i + 1, 0] = update2(targets[i, 0], targets[i, 1] * 0.5) targets[i + 1, 1] = update3(targets[i, 1], targets[i, 0] * 2.0, inp[i]) - diff = np.mean(np.abs(results['a']['pop0.0/op1.0'].values - targets[:-1, 0])) + \ - np.mean(np.abs(results['b']['pop1.0/op7.0'].values - targets[:-1, 1])) + diff = np.mean(np.abs(results['a']['pop0.0/op1.0'].values[1:] - targets[:-2, 0])) + \ + np.mean(np.abs(results['b']['pop1.0/op7.0'].values[1:] - targets[:-2, 1])) assert diff == pytest.approx(0., rel=1e-6, abs=1e-6)