Skip to content

Commit

Permalink
debugged midpoint method (should work for operator hierarchies now as…
Browse files Browse the repository at this point in the history
… well)
  • Loading branch information
Richert committed Jul 10, 2019
1 parent 5117700 commit ef7033f
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions pyrates/backend/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,9 +737,7 @@ def parse_equation_system(equations: list, equation_args: dict, backend: tp.Any,

# second rhs evaluation + combination of the two
updates.update({key: arg for key, arg in equation_args.items() if 'inputs' in key})
equations, updates = update_rhs(equations, updates, update_num,
"(var_placeholder + 0.5*update_placeholder)")
equation_args = update_equation_args(equation_args, updates)
equations, _ = update_rhs(equations, updates, update_num, "(var_placeholder + 0.5*update_placeholder)")
_, equation_args = parse_equations(equations=equations, equation_args=equation_args, backend=backend, **kwargs)

elif solver == 'rk23':
Expand All @@ -763,10 +761,9 @@ def parse_equation_system(equations: list, equation_args: dict, backend: tp.Any,

# second rhs evaluation
updates.update({key: arg for key, arg in equation_args.items() if 'inputs' in key})
equations_tmp, updates = update_rhs(deepcopy(equations),
updates, update_num,
equations_tmp, updates = update_rhs(deepcopy(equations), updates, update_num,
"(var_placeholder + 0.5*update_placeholder)")
equation_args = update_equation_args(equation_args, updates)
#equation_args = update_equation_args(equation_args, updates)
state_vars = {key: var for key, var in equation_args.items()
if any([orig_key in key for orig_key in state_vars_orig])}
update_num += 1
Expand All @@ -781,7 +778,7 @@ def parse_equation_system(equations: list, equation_args: dict, backend: tp.Any,
updates.update(updates_new)
equations, updates = update_rhs(equations, updates, update_num,
"(var_placeholder - var_placeholder_1 + 2*update_placeholder)")
equation_args = update_equation_args(equation_args, updates)
#equation_args = update_equation_args(equation_args, updates)
state_vars = {key: var for key, var in equation_args.items()
if any([orig_key in key for orig_key in state_vars_orig])}
update_num += 1
Expand All @@ -792,7 +789,7 @@ def parse_equation_system(equations: list, equation_args: dict, backend: tp.Any,
backend.add_layer()

# combination of 3 rhs evaluations a'la rk23
equation_args = update_equation_args(equation_args, updates)
#equation_args = update_equation_args(equation_args, updates)
equations = []
for var_info in state_vars_orig:
node, op, var = var_info.split('/')
Expand Down Expand Up @@ -848,7 +845,7 @@ def parse_equations(equations: list, equation_args: dict, backend: tp.Any, **kwa
for key, inp in inputs.items():
inp_tmp = equation_args[inp]
op_args[key] = inp_tmp
if type(inp_tmp) is dict:
if type(inp_tmp) is dict and 'vtype' in inp_tmp:
unprocessed_inputs.append(key)

# parse operator variables in backend
Expand Down Expand Up @@ -881,11 +878,12 @@ def parse_equations(equations: list, equation_args: dict, backend: tp.Any, **kwa
#######################

for key, var in parser.vars.items():
if key != "inputs" and key != "rhs" and key != "dt":
if key != "inputs" and key != "rhs" and key != "dt" and key not in inputs:
equation_args[f"{scope}/{key}"] = var

for key, inp in inputs.items():
equation_args[inp] = parser.vars[key]
if key in unprocessed_inputs:
equation_args[inp] = parser.vars[key]

backend.add_layer()

Expand Down

0 comments on commit ef7033f

Please sign in to comment.