Skip to content

Commit

Permalink
removed a bug where differential equations with a constant right-hand…
Browse files Browse the repository at this point in the history
… side were not properly handled by the automated compute graph optimization
  • Loading branch information
Richert committed Mar 31, 2024
1 parent 01581eb commit 14703aa
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -9,6 +9,8 @@ Changelog

- adjusted the call of the max/min functions: Use `maxi` and `mini` in the equations. Both functions take two input arguments, and return the larger/smaller one, respectively
- updated the PyRates reference in the readme and on the documentation website (using the PLOS CB paper now instead of the arxiv preprint)
- removed a bug where differential equations with a constant right-hand side were not properly handled by the automated compute graph optimization
- resolved an issue with the fortran backend where complex data types were not properly processed during the code generation

1.0.4
-----
Expand Down
6 changes: 3 additions & 3 deletions pyrates/backend/computegraph.py
Expand Up @@ -306,10 +306,10 @@ def eval_nodes(self, nodes: Iterable):

def eval_node(self, n):

inputs = [self.eval_node(inp) for inp in self.predecessors(n)]
inputs = tuple([self.eval_node(inp) for inp in self.predecessors(n)])
node = self.get_var(n)
if isinstance(node, ComputeOp):
return node.func(*tuple(inputs))
return node.func(*inputs)
return node.value

def eval_subgraph(self, n):
Expand Down Expand Up @@ -344,7 +344,7 @@ def compile(self):
self.eval_subgraph(inp)

# evaluate node if all its inputs are constants
if all([self.get_var(inp).is_constant for inp in self.predecessors(node)]):
if all([self.get_var(inp).is_constant for inp in self.predecessors(node)]) and node not in self._eq_nodes:
self.eval_subgraph(node)

# remove unconnected nodes and constants from graph
Expand Down
2 changes: 2 additions & 0 deletions pyrates/backend/fortran/fortran_backend.py
Expand Up @@ -458,6 +458,8 @@ def _get_dtype(self, dtype: Union[str, np.dtype]):
dtype = self._float_precision
if 'float' in dtype:
dtype = 'double precision' if '64' in dtype else 'real'
elif 'complex' in dtype:
dtype = 'complex'
else:
dtype = 'integer'
return dtype
Expand Down

0 comments on commit 14703aa

Please sign in to comment.