Skip to content

Commit

Permalink
Optimise passthrough Nodes out of model
Browse files Browse the repository at this point in the history
Optimises passthrough Nodes out of a model by combining their input and
out signals. Passthrough Nodes can be retained by setting the
`optimize_out` option in the config to False; regardless of this setting
passthrough Nodes associated with Neurons will be removed.
  • Loading branch information
mundya committed Feb 19, 2016
1 parent fbce566 commit 1e9d9c3
Show file tree
Hide file tree
Showing 10 changed files with 1,131 additions and 42 deletions.
16 changes: 7 additions & 9 deletions nengo_spinnaker/builder/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,19 @@ def get_neurons_sink(model, connection):
"""Get the sink for connections into the neurons of an ensemble."""
ens = model.object_operators[connection.post_obj.ensemble]

if isinstance(connection.pre_obj, nengo.ensemble.Neurons):
# Connections from Neurons can go straight to the Neurons
return spec(ObjectPort(ens, EnsembleInputPort.neurons))
elif np.all(connection.transform[1:] == connection.transform[0]):
if (connection.transform.ndim == 2 and
np.all(connection.transform[1:] == connection.transform[0])):
# Connections from non-neurons to Neurons where the transform delivers
# the same value to all neurons are treated as global inhibition
# connection.
# Return a signal to the correct port.
return spec(ObjectPort(ens, EnsembleInputPort.global_inhibition))
else:
# We don't support arbitrary connections into neurons
raise NotImplementedError(
"SpiNNaker does not support arbitrary connections into Neurons. "
"If this is a serious hindrance please open an issue on GitHub."
)
# - Connections from Neurons can go straight to the Neurons.
# - Otherwise we don't support arbitrary connections into neurons, but
# we allow them because they may be optimised out later when we come
# to remove passthrough nodes.
return spec(ObjectPort(ens, EnsembleInputPort.neurons))


ensemble_builders = collections_ext.registerabledict()
Expand Down
26 changes: 18 additions & 8 deletions nengo_spinnaker/builder/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ class NodeIOController(object):
thread which manages IO. This thread must have a `stop` method which
causes the thread to stop executing. See the Ethernet implementation for an
example.
Attributes
----------
host_network : :py:class:`~nengo.Network`
Network containing everything required to perform the host-side of the
simulation.
passthrough_nodes : {Node: operator, ...}
Map of passthrough Nodes to the operators which simulate them on
SpiNNaker, this is exposed so that some passthrough Nodes may be
optimised out.
"""

def __init__(self):
Expand All @@ -64,7 +74,7 @@ def __init__(self):

# Store objects that we've added
self._f_of_t_nodes = dict()
self._passthrough_nodes = dict()
self.passthrough_nodes = dict()
self._added_nodes = set()
self._added_conns = set()
self._input_nodes = dict()
Expand Down Expand Up @@ -113,7 +123,7 @@ def build_node(self, model, node):

op = Filter(node.size_in, n_cores_per_chip=n_cores,
n_chips=n_chips)
self._passthrough_nodes[node] = op
self.passthrough_nodes[node] = op
model.object_operators[node] = op
elif f_of_t:
# If the Node is a function of time then add a new value source for
Expand Down Expand Up @@ -147,18 +157,18 @@ def build_node_probe(self, model, probe):

def get_node_source(self, model, cn):
"""Get the source for a connection originating from a Node."""
if cn.pre_obj in self._passthrough_nodes:
if cn.pre_obj in self.passthrough_nodes:
# If the Node is a passthrough Node then we return a reference
# to the Filter operator we created earlier regardless.
return spec(ObjectPort(self._passthrough_nodes[cn.pre_obj],
return spec(ObjectPort(self.passthrough_nodes[cn.pre_obj],
OutputPort.standard))
elif cn.pre_obj in self._f_of_t_nodes:
# If the Node is a function of time Node then we return a
# reference to the value source we created earlier.
return spec(ObjectPort(self._f_of_t_nodes[cn.pre_obj],
OutputPort.standard))
elif (type(cn.post_obj) is nengo.Node and
cn.post_obj not in self._passthrough_nodes):
cn.post_obj not in self.passthrough_nodes):
# If this connection goes from a Node to another Node (exactly, not
# any subclasses) then we just add both nodes and the connection to
# the host model.
Expand Down Expand Up @@ -199,13 +209,13 @@ def get_spinnaker_source_for_node(self, model, cn): # pragma: no cover

def get_node_sink(self, model, cn):
"""Get the sink for a connection terminating at a Node."""
if cn.post_obj in self._passthrough_nodes:
if cn.post_obj in self.passthrough_nodes:
# If the Node is a passthrough Node then we return a reference
# to the Filter operator we created earlier regardless.
return spec(ObjectPort(self._passthrough_nodes[cn.post_obj],
return spec(ObjectPort(self.passthrough_nodes[cn.post_obj],
InputPort.standard))
elif (type(cn.pre_obj) is nengo.Node and
cn.pre_obj not in self._passthrough_nodes):
cn.pre_obj not in self.passthrough_nodes):
# If this connection goes from a Node to another Node (exactly, not
# any subclasses) then we just add both nodes and the connection to
# the host model.
Expand Down
5 changes: 5 additions & 0 deletions nengo_spinnaker/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def add_spinnaker_params(config):
"n_chips",
IntParam(default=None, low=1, optional=True)
)
# Add optimisation control parameters to (passthrough) Nodes. None means
# that a heuristic will be used to determine if the passthrough Node should
# be removed.
config[nengo.Node].set_param("optimize_out",
BoolParam(default=None, optional=True))

# Add profiling parameters to Ensembles
config[nengo.Ensemble].set_param("profile", BoolParam(default=False))
Expand Down
8 changes: 8 additions & 0 deletions nengo_spinnaker/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from .node_io import Ethernet
from .rc import rc
from .utils.config import getconfig
from .utils.model import (get_force_removal_passnodes,
optimise_out_passthrough_nodes)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -100,6 +102,12 @@ def __init__(self, network, dt=0.001, period=10.0, timescale=1.0):
self.model = Model(dt=dt, machine_timestep=machine_timestep,
decoder_cache=get_default_decoder_cache())
self.model.build(network, **builder_kwargs)

forced_removals = get_force_removal_passnodes(network)
optimise_out_passthrough_nodes(self.model,
self.io_controller.passthrough_nodes,
network.config, forced_removals)

logger.info("Build took {:.3f} seconds".format(time.time() -
start_build))

Expand Down
Loading

0 comments on commit 1e9d9c3

Please sign in to comment.