diff --git a/nengo_spinnaker/simulator.py b/nengo_spinnaker/simulator.py index cfddc63..8f50eb1 100644 --- a/nengo_spinnaker/simulator.py +++ b/nengo_spinnaker/simulator.py @@ -104,7 +104,7 @@ def __init__(self, network, dt=0.001, period=10.0, timescale=1.0): self.model.build(network, **builder_kwargs) forced_removals = get_force_removal_passnodes(network) - optimise_out_passthrough_nodes(self.model.connection_map, + optimise_out_passthrough_nodes(self.model, self.io_controller.passthrough_nodes, network.config, forced_removals) diff --git a/nengo_spinnaker/utils/model.py b/nengo_spinnaker/utils/model.py index 9ddbe08..e0f3a92 100644 --- a/nengo_spinnaker/utils/model.py +++ b/nengo_spinnaker/utils/model.py @@ -70,7 +70,7 @@ def find_connected_nodes(node, all_children=None): return force_removal -def optimise_out_passthrough_nodes(conn_map, passthrough_nodes, config, +def optimise_out_passthrough_nodes(model, passthrough_nodes, config, forced_removals=set()): """Remove passthrough Nodes from a network. @@ -89,10 +89,15 @@ def optimise_out_passthrough_nodes(conn_map, passthrough_nodes, config, getconfig(config, node, "optimize_out")) if remove_node or remove_node is None: removed = remove_operator_from_connection_map( - conn_map, operator, force=bool(remove_node)) + model.conn_map, operator, force=bool(remove_node)) # Log if the Node was removed if removed: + if node in model.objects_operators: + model.objects_operators.pop(node) + else: + model.extra_operators.remove(operator) + logger.info("Passthrough Node {!s} was optimized out".format(node)) diff --git a/tests/builder/test_model.py b/tests/builder/test_model.py index 08d99f6..b3c2f46 100644 --- a/tests/builder/test_model.py +++ b/tests/builder/test_model.py @@ -407,6 +407,7 @@ class G(object): cm.add_connection(c, None, model.SignalParameters(), None, e, None, None) cm.add_connection(d, None, model.SignalParameters(), None, e, None, None) cm.add_connection(e, None, model.SignalParameters(), None, f, None, None) + cm._connections[f][None] = list() # Remove the sinkless filters removed = model.remove_sinkless_objects(cm, mock.Mock)