Skip to content

Commit

Permalink
moved input parsing from computegraph to backend classes
Browse files Browse the repository at this point in the history
  • Loading branch information
Richert committed Jul 10, 2019
1 parent d138255 commit 5117700
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 68 deletions.
18 changes: 3 additions & 15 deletions pyrates/backend/computegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def __init__(self,
# edge operators
equations, variables_tmp = self._collect_op_layers(layers=[0], exclude=False, op_identifier="edge_from_")
variables.update(variables_tmp)
if equations:
self.backend._input_layer_added = True

# node operators
equations_tmp, variables_tmp = self._collect_op_layers(layers=[], exclude=True, op_identifier="edge_from_")
Expand Down Expand Up @@ -391,21 +393,7 @@ def run(self,
inp_dict[var.name] = np.reshape(val[:, i], (sim_steps,) + tuple(var.shape))
i += 1

# add inputs to graph
self.backend.add_layer(to_beginning=True)

# create counting index for input variables
in_idx = self.backend.add_var(vtype='state_var', name='in_var_idx', dtype='int32', shape=(1,), value=0,
scope="network_inputs")

for key, var in inp_dict.items():
var_name = f"{var.short_name}_inp" if hasattr(var, 'short_name') else "var_inp"
in_var = self.backend.add_var(vtype='state_var', name=var_name, scope="network_inputs", value=var)
in_var_idx = self.backend.add_op('index', in_var, in_idx, scope="network_inputs")
self.backend.add_op('=', self.backend.vars[key], in_var_idx, scope="network_inputs")

# create increment operator for counting index
self.backend.add_op('+=', in_idx, np.ones((1,), dtype='int32'), scope="network_inputs")
self.backend.add_input_layer(inputs=inp_dict)

# run simulation
################
Expand Down
105 changes: 75 additions & 30 deletions pyrates/backend/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,7 @@ def __init__(self,
self._rt_optimization = jit_compile
self._base_layer = 0
self._output_layer_added = False
self._input_layer_added = False
self._imports = ["import numpy as np", "from pyrates.backend.funcs import *"]
if imports:
for imp in imports:
Expand Down Expand Up @@ -1146,23 +1147,48 @@ def add_layer(self, to_beginning=False) -> None:
self.layer = len(self.layers) - self._base_layer
self.layers.append([])

def add_output_layer(self, outputs, sampling_steps, out_shapes):
def add_output_layer(self, outputs, sampling_steps, out_shapes) -> dict:
"""
output_col = {}
Parameters
----------
outputs
sampling_steps
out_shapes
Returns
-------
"""

# create counting index for collector variables
out_idx = self.add_var(vtype='state_var', name='out_var_idx', dtype='int32', shape=(1,), value=0,
scope="output_collection")
output_col = {}

# add output storage layer to the graph
if self._output_layer_added:

out_idx = self.get_var("output_collection/out_var_idx")

# jump to output layer
self.top_layer()

else:

# add output layer
self.add_layer()
self._output_layer_added = True

# create counting index for collector variables
out_idx = self.add_var(vtype='state_var', name='out_var_idx', dtype='int32', shape=(1,), value=0,
scope="output_collection")

# create increment operator for counting index
self.next_layer()
self.add_op('+=', out_idx, np.ones((1,), dtype='int32'), scope="output_collection")
self.previous_layer()

# add collector variables to the graph
for i, (var_col) in enumerate(outputs):

shape = (sampling_steps + 1, len(var_col)) + out_shapes[i]
key = f"output_col_{i}"
output_col[key] = self.add_var(vtype='state_var', name=f"out_col_{i}", scope="output_collection",
Expand All @@ -1172,11 +1198,39 @@ def add_output_layer(self, outputs, sampling_steps, out_shapes):
# add collect operation to the graph
self.add_op('=', output_col[key], var_stack, out_idx, scope="output_collection")

# create increment operator for counting index
self.add_op('+=', out_idx, np.ones((1,), dtype='int32'), scope="output_collection")

return output_col

def add_input_layer(self, inputs: dict) -> None:
    """Attach external input feeds to the beginning of the compute graph.

    Parameters
    ----------
    inputs
        Mapping from backend variable keys to per-step input arrays. Each value
        may expose a `short_name` attribute used to label the stored input.

    Returns
    -------
    None
    """

    # position the backend at the input layer, creating it on first use
    if not self._input_layer_added:
        self.add_layer(to_beginning=True)
    else:
        self.bottom_layer()

    # step counter shared by all input feeds
    # NOTE(review): this counter is (re-)added even when the input layer already
    # exists — confirm `add_var` tolerates/overwrites duplicates, cf. how
    # `add_output_layer` reuses its index via `get_var`.
    step_idx = self.add_var(vtype='state_var', name='in_var_idx', dtype='int32', shape=(1,),
                            value=0, scope="network_inputs")

    # for every feed: store the full input array, index the current step,
    # and assign that slice to the target network variable
    for target_key, feed in inputs.items():
        feed_label = f"{feed.short_name}_inp" if hasattr(feed, 'short_name') else "var_inp"
        stored_feed = self.add_var(vtype='state_var', name=feed_label, scope="network_inputs", value=feed)
        current_slice = self.add_op('index', stored_feed, step_idx, scope="network_inputs")
        self.add_op('=', self.vars[target_key], current_slice, scope="network_inputs")

    # advance the step counter once per simulation step
    self.add_op('+=', step_idx, np.ones((1,), dtype='int32'), scope="network_inputs")

def next_layer(self) -> None:
"""Jump to next layer in stack. If we are already at end of layer stack, add new layer to the stack and jump to
that.
Expand Down Expand Up @@ -1251,41 +1305,32 @@ def clear(self) -> None:
self.layer = 0
rmtree(self._build_dir)

def get_var(self, name, updated=True) -> NumpyVar:
def get_layer(self, idx) -> list:
    """Retrieve a layer (list of operations) from the graph.

    Parameters
    ----------
    idx
        Position of the layer in the stack, relative to the base layer.

    Returns
    -------
    list
        The operations stored at that layer.
    """
    position = self._base_layer + idx
    return self.layers[position]

def get_var(self, name):
"""Retrieve variable from graph.
Parameters
----------
name
Identifier of the variable.
updated
If true, return updated value of a state-variable. Else, return the old, non-updated variable of a
state-variable.
Returns
-------
NumpyVar
Variable from graph.
"""
if updated:
return self.vars[name]
else:
try:
return self.vars[f'{name}_old']
except KeyError:
return self.vars[name]

def get_layer(self, idx) -> list:
"""Retrieve layer from graph.
Parameters
----------
idx
Position of layer in stack.
"""
return self.layers[self._base_layer + idx]
return self.vars[name]

def eval_var(self, var) -> np.ndarray:
"""Get value of variable.
Expand Down
25 changes: 17 additions & 8 deletions pyrates/backend/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def parse_expr(self):
if self.lhs_key in self.vars and not self._instantaneous:

# create new variable for lhs update
name = f"{self.lhs_key}_update"
name = f"{self.lhs_key}_delta"
i = 0
name = f"{name}_{i}"
while name in self.vars:
Expand Down Expand Up @@ -728,7 +728,8 @@ def parse_equation_system(equations: list, equation_args: dict, backend: tp.Any,
update_num += 1
state_vars = {key: var for key, var in equation_args.items()
if (type(var) is dict) and ('vtype' in var) and (var['vtype'] == 'state_var')}
equations_tmp, state_vars = update_lhs(deepcopy(equations), state_vars, update_num, {})
state_vars_orig = dict(state_vars.items())
equations_tmp, state_vars = update_lhs(deepcopy(equations), state_vars, update_num, state_vars_orig)
equation_args.update(state_vars)
updates, equation_args = parse_equations(equations=equations_tmp, equation_args=equation_args, backend=backend,
**kwargs)
Expand All @@ -751,7 +752,7 @@ def parse_equation_system(equations: list, equation_args: dict, backend: tp.Any,
state_vars = {key: var for key, var in equation_args.items()
if (type(var) is dict) and ('vtype' in var) and (var['vtype'] == 'state_var')}
state_vars_orig = dict(state_vars.items())
equations_tmp, state_vars = update_lhs(deepcopy(equations), state_vars, update_num, {})
equations_tmp, state_vars = update_lhs(deepcopy(equations), state_vars, update_num, state_vars_orig)
for var_orig in state_vars_orig.copy().keys():
if not any([var_orig in var for var in state_vars]):
state_vars_orig.pop(var_orig)
Expand Down Expand Up @@ -1029,7 +1030,7 @@ def update_lhs(equations: list, equation_args: dict, update_num: int, var_dict:
equations[i][j] = (f"{lhs} = dt * ({rhs})", scope)
if add_to_args:
for var_key, var in var_dict.copy().items():
if var_key in key:
if var_key == key or f"{var_key}_upd_" in key:
arg = var
break
updated_args[f"{node}/{op}/{new_var}"] = arg
Expand All @@ -1043,21 +1044,29 @@ def update_equation_args(args: dict, updates: dict) -> dict:
Parameters
----------
args
udpates
updates
Returns
-------
"""
args.update(updates)
args_new = {}

for key, arg in args.items():
if key in updates:
args_new[key] = updates[key]
else:
args_new[key] = arg

inputs = [key for key in args if 'inputs' in key]
for inp in inputs:
for in_key, in_map in args[inp].copy().items():
for upd in updates:
if in_map in upd:
args[inp].update({upd.split('/')[-1]: upd})
args_new[inp].update({upd.split('/')[-1]: upd})
break
return args

return args_new


def parse_dict(var_dict: dict, backend, **kwargs) -> dict:
Expand Down
23 changes: 23 additions & 0 deletions pyrates/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,29 @@ def broadcast(self, op1: Any, op2: Any, **kwargs) -> tuple:

return super().broadcast(op1, op2, **kwargs)

def get_var(self, name):
    """Retrieve a variable from the graph by name.

    Falls back to tensorflow-style suffixed names (e.g. ``"name:0"``) when no
    exact match exists.

    Parameters
    ----------
    name
        Identifier of the variable.

    Returns
    -------
    NumpyVar
        Variable from graph.

    Raises
    ------
    KeyError
        If neither an exact nor a suffixed match is found.
    """
    try:
        return self.vars[name]
    except KeyError as missing:
        # tf appends ":<output_index>" to variable names; take the first match
        suffixed = f"{name}:"
        for key in self.vars:
            if suffixed in key:
                return self.vars[key]
        raise missing

def stack_vars(self, vars, **kwargs):
var_count = {}
for var in vars:
Expand Down

0 comments on commit 5117700

Please sign in to comment.