From f6a3da1d7122913d3a376f45333791a8fe7b1cdd Mon Sep 17 00:00:00 2001 From: Andrew Mundy Date: Tue, 20 Oct 2015 15:36:36 +0100 Subject: [PATCH] Optimise passthrough Nodes out of model Optimises passthrough Nodes out of a model by combining their input and out signals. Passthrough Nodes can be retained by setting the `optimize_out` option in the config to False; regardless of this setting passthrough Nodes associated with Neurons will be removed. --- nengo_spinnaker/builder/ensemble.py | 16 +- nengo_spinnaker/builder/node.py | 26 +- nengo_spinnaker/config.py | 5 + nengo_spinnaker/simulator.py | 8 + nengo_spinnaker/utils/model.py | 385 +++++++++++++ regression-tests/test_global_inhibition.py | 48 ++ tests/builder/test_ensemble.py | 33 +- tests/builder/test_model.py | 1 + tests/test_config.py | 12 +- tests/utils/test_utils_model.py | 639 +++++++++++++++++++++ 10 files changed, 1131 insertions(+), 42 deletions(-) create mode 100644 nengo_spinnaker/utils/model.py create mode 100644 tests/utils/test_utils_model.py diff --git a/nengo_spinnaker/builder/ensemble.py b/nengo_spinnaker/builder/ensemble.py index 2edfe46..15b57ee 100644 --- a/nengo_spinnaker/builder/ensemble.py +++ b/nengo_spinnaker/builder/ensemble.py @@ -59,21 +59,19 @@ def get_neurons_sink(model, connection): """Get the sink for connections into the neurons of an ensemble.""" ens = model.object_operators[connection.post_obj.ensemble] - if isinstance(connection.pre_obj, nengo.ensemble.Neurons): - # Connections from Neurons can go straight to the Neurons - return spec(ObjectPort(ens, EnsembleInputPort.neurons)) - elif np.all(connection.transform[1:] == connection.transform[0]): + if (connection.transform.ndim == 2 and + np.all(connection.transform[1:] == connection.transform[0])): # Connections from non-neurons to Neurons where the transform delivers # the same value to all neurons are treated as global inhibition # connection. # Return a signal to the correct port. return spec(ObjectPort(ens, EnsembleInputPort.global_inhibition)) else: - # We don't support arbitrary connections into neurons - raise NotImplementedError( - "SpiNNaker does not support arbitrary connections into Neurons. " - "If this is a serious hindrance please open an issue on GitHub." - ) + # - Connections from Neurons can go straight to the Neurons. + # - Otherwise we don't support arbitrary connections into neurons, but + # we allow them because they may be optimised out later when we come + # to remove passthrough nodes. + return spec(ObjectPort(ens, EnsembleInputPort.neurons)) ensemble_builders = collections_ext.registerabledict() diff --git a/nengo_spinnaker/builder/node.py b/nengo_spinnaker/builder/node.py index 69b5489..44e860b 100644 --- a/nengo_spinnaker/builder/node.py +++ b/nengo_spinnaker/builder/node.py @@ -55,6 +55,16 @@ class NodeIOController(object): thread which manages IO. This thread must have a `stop` method which causes the thread to stop executing. See the Ethernet implementation for an example. + + Attributes + ---------- + host_network : :py:class:`~nengo.Network` + Network containing everything required to perform the host-side of the + simulation. + passthrough_nodes : {Node: operator, ...} + Map of passthrough Nodes to the operators which simulate them on + SpiNNaker, this is exposed so that some passthrough Nodes may be + optimised out. """ def __init__(self): @@ -64,7 +74,7 @@ def __init__(self): # Store objects that we've added self._f_of_t_nodes = dict() - self._passthrough_nodes = dict() + self.passthrough_nodes = dict() self._added_nodes = set() self._added_conns = set() self._input_nodes = dict() @@ -113,7 +123,7 @@ def build_node(self, model, node): op = Filter(node.size_in, n_cores_per_chip=n_cores, n_chips=n_chips) - self._passthrough_nodes[node] = op + self.passthrough_nodes[node] = op model.object_operators[node] = op elif f_of_t: # If the Node is a function of time then add a new value source for @@ -147,10 +157,10 @@ def build_node_probe(self, model, probe): def get_node_source(self, model, cn): """Get the source for a connection originating from a Node.""" - if cn.pre_obj in self._passthrough_nodes: + if cn.pre_obj in self.passthrough_nodes: # If the Node is a passthrough Node then we return a reference # to the Filter operator we created earlier regardless. - return spec(ObjectPort(self._passthrough_nodes[cn.pre_obj], + return spec(ObjectPort(self.passthrough_nodes[cn.pre_obj], OutputPort.standard)) elif cn.pre_obj in self._f_of_t_nodes: # If the Node is a function of time Node then we return a @@ -158,7 +168,7 @@ def get_node_source(self, model, cn): return spec(ObjectPort(self._f_of_t_nodes[cn.pre_obj], OutputPort.standard)) elif (type(cn.post_obj) is nengo.Node and - cn.post_obj not in self._passthrough_nodes): + cn.post_obj not in self.passthrough_nodes): # If this connection goes from a Node to another Node (exactly, not # any subclasses) then we just add both nodes and the connection to # the host model. @@ -199,13 +209,13 @@ def get_spinnaker_source_for_node(self, model, cn): # pragma: no cover def get_node_sink(self, model, cn): """Get the sink for a connection terminating at a Node.""" - if cn.post_obj in self._passthrough_nodes: + if cn.post_obj in self.passthrough_nodes: # If the Node is a passthrough Node then we return a reference # to the Filter operator we created earlier regardless. - return spec(ObjectPort(self._passthrough_nodes[cn.post_obj], + return spec(ObjectPort(self.passthrough_nodes[cn.post_obj], InputPort.standard)) elif (type(cn.pre_obj) is nengo.Node and - cn.pre_obj not in self._passthrough_nodes): + cn.pre_obj not in self.passthrough_nodes): # If this connection goes from a Node to another Node (exactly, not # any subclasses) then we just add both nodes and the connection to # the host model. diff --git a/nengo_spinnaker/config.py b/nengo_spinnaker/config.py index bdf7119..760ca93 100644 --- a/nengo_spinnaker/config.py +++ b/nengo_spinnaker/config.py @@ -40,6 +40,11 @@ def add_spinnaker_params(config): "n_chips", IntParam(default=None, low=1, optional=True) ) + # Add optimisation control parameters to (passthrough) Nodes. None means + # that a heuristic will be used to determine if the passthrough Node should + # be removed. + config[nengo.Node].set_param("optimize_out", + BoolParam(default=None, optional=True)) # Add profiling parameters to Ensembles config[nengo.Ensemble].set_param("profile", BoolParam(default=False)) diff --git a/nengo_spinnaker/simulator.py b/nengo_spinnaker/simulator.py index f117443..aca2b2b 100644 --- a/nengo_spinnaker/simulator.py +++ b/nengo_spinnaker/simulator.py @@ -14,6 +14,8 @@ from .node_io import Ethernet from .rc import rc from .utils.config import getconfig +from .utils.model import (get_force_removal_passnodes, + optimise_out_passthrough_nodes) logger = logging.getLogger(__name__) @@ -100,6 +102,12 @@ def __init__(self, network, dt=0.001, period=10.0, timescale=1.0): self.model = Model(dt=dt, machine_timestep=machine_timestep, decoder_cache=get_default_decoder_cache()) self.model.build(network, **builder_kwargs) + + forced_removals = get_force_removal_passnodes(network) + optimise_out_passthrough_nodes(self.model, + self.io_controller.passthrough_nodes, + network.config, forced_removals) + logger.info("Build took {:.3f} seconds".format(time.time() - start_build)) diff --git a/nengo_spinnaker/utils/model.py b/nengo_spinnaker/utils/model.py new file mode 100644 index 0000000..00f9e40 --- /dev/null +++ b/nengo_spinnaker/utils/model.py @@ -0,0 +1,385 @@ +from __future__ import absolute_import + +import collections +import logging +import nengo.synapses +import numpy as np +from six.moves import filter as sfilter +from six import iteritems, itervalues + +from nengo_spinnaker.builder.ensemble import EnsembleTransmissionParameters +from nengo_spinnaker.builder.model import ( + ReceptionParameters, SignalParameters +) +from nengo_spinnaker.builder.node import ( + PassthroughNodeTransmissionParameters, NodeTransmissionParameters +) +from nengo_spinnaker.builder.ports import EnsembleInputPort +from nengo_spinnaker.utils.config import getconfig + + +logger = logging.getLogger(__name__) + + +def get_force_removal_passnodes(network): + """Get a set of which passthrough Nodes should be forcibly removed + regardless of the configuration options, currently this is any passthrough + Nodes which are directly or indirectly connected (through other passthrough + Nodes) to Neurons. + """ + def is_ptn(n): + return isinstance(n, nengo.Node) and n.output is None + + # Build a dictionary representing which objects are connected to which + # other objects (undirected graph). + all_io = collections.defaultdict(list) + for conn in network.all_connections: + all_io[conn.pre_obj].append(conn.post_obj) + all_io[conn.post_obj].append(conn.pre_obj) + + # Find passthrough Nodes which are directly connected to neurons and mark + # them for removal. + force_removal = {n for n in network.all_nodes if is_ptn(n) and + any(isinstance(o, nengo.ensemble.Neurons) + for o in all_io[n])} + + # For each passthrough Node perform a search to determine which other + # passthrough Nodes it is connected to. + def find_connected_nodes(node, all_children=None): + all_children = all_children or set() # Ensure we have a set + all_children.add(node) # Mark ourselves as visited + + # Recursively get all child nodes, updating `all_children` in place + for n in sfilter(lambda m: is_ptn(m) and m not in all_children, + all_io[node]): + find_connected_nodes(n, all_children) + + return all_children + + # Find all the passthrough Nodes that connect directly or indirectly to + # neurons. + remaining_nodes = set(force_removal) + while remaining_nodes: + # Get a start node then mark all connected Nodes as requiring removal + # and remove them from the list of remaining Nodes. + start_node = next(iter(remaining_nodes)) + nodes = find_connected_nodes(start_node) + force_removal.update(nodes) + remaining_nodes.difference_update(nodes) + + return force_removal + + +def optimise_out_passthrough_nodes(model, passthrough_nodes, config, + forced_removals=set()): + """Remove passthrough Nodes from a network. + + Other Parameters + ---------------- + forced_removals : {Node, ...} + Set of Nodes which should be removed regardless of the configuration + settings. + """ + for node, operator in iteritems(passthrough_nodes): + removed = False + + # Determine whether to remove the Node or not (if True then definitely + # remove, if None then remove if it doesn't worsen network usage). + remove_node = (node in forced_removals or + getconfig(config, node, "optimize_out")) + if remove_node or remove_node is None: + removed = remove_operator_from_connection_map( + model.connection_map, operator, force=bool(remove_node)) + + # Log if the Node was removed + if removed: + if node in model.object_operators: + model.object_operators.pop(node) + else: + model.extra_operators.remove(operator) + + logger.info("Passthrough Node {!s} was optimized out".format(node)) + + +def remove_operator_from_connection_map(conn_map, target, force=True): + """Remove an operator from a connection map by combining the connections + that lead to and from the operator. + + Parameters + ---------- + conn_map : `ConnectionMap` + Connection map from which the operator will be removed. Note that the + connection map will be modified. + target : object + Operator to remove from the connection map. + + Other Parameters + ---------------- + force : bool + If False then the operator will only be optimised out if it is expected + that doing so will break network performance. If True (the default) + then the operator will be removed regardless. + """ + # Grab the old connection map + old_conns = conn_map._connections.copy() + saved_conns = conn_map._connections.copy() + + # Grab all the connections which are transmitted by the operator. + if target not in old_conns: + return True + out_conns = list(_get_port_kwargs(old_conns.pop(target))) + + # Compute the most packets any object which receives packets from the + # operator will receive if the operator is not removed. + old_max_packets = _get_max_packets_received(out_conns) + + # Prepare to compute the new equivalent of this value + new_rx = collections.defaultdict(lambda: 0) + + # Create a new empty connection map dictionary and update the connection + # map to use the new dictionary + conns = collections.defaultdict(lambda: collections.defaultdict(list)) + conn_map._connections = conns + + # Copy across all connections from the old connection map; every time we + # encounter the operator as the sink of a signal we multiply that signal by + # each of the outgoing connections in turn and add those connections to the + # new connection map instead. + for kwargs in _iter_connection_map_as_kwargs(old_conns): + # The target will never appear as a source object as we removed it as a + # key before. Therefore determine what to do by looking at the sink + # object. + if kwargs["sink_object"] is target: + # If the sink object is the target then create multiply this + # connection with each of the outgoing signals for the target. + for new_kwargs in _multiply_signals(kwargs, out_conns): + conn_map.add_connection(**new_kwargs) + + # Update the new number of packets that will be received by + # each sink (assuming that zeroed rows will be removed from the + # transform) + weight = np.sum(np.any( + new_kwargs["transmission_parameters"].transform != 0.0, + axis=1) + ) + new_rx[new_kwargs["sink_object"]] += weight + else: + # Otherwise (re-)add this connection to the connection map. + conn_map.add_connection(**kwargs) + + # Determine whether the changes made should be discarded. + new_max_packets = max(itervalues(new_rx)) + discard_changes = new_max_packets > old_max_packets + + # If not forced and we caused a worsening in network usage then copy the + # old connection map back in and discard our changes. + if not force and discard_changes: + conn_map._connections = saved_conns + return False + else: + return True + + +def _iter_connection_map_as_kwargs(conn_dict): + """Iterate over a connection map and yield keywords for every connection. + + Yields + ------ + dict + Keyword arguments for `ConnectionMap.add_connection` + """ + # Iterate through the dictionary + for source_object, ports_and_signals in iteritems(conn_dict): + for kwargs in _get_port_kwargs(ports_and_signals): + # Add the source object to the kwargs and yield + kwargs["source_object"] = source_object + yield kwargs + + +def _get_port_kwargs(ports_and_signals): + """ + Yields + ------ + dict + Keyword arguments for `ConnectionMap.add_connection` + """ + # Iterate through the dictionary + for source_port, signals_and_sinks in iteritems(ports_and_signals): + for signal_and_sink in signals_and_sinks: + for (sink_object, sink_port, + reception_parameters) in signal_and_sink.sinks: + # Break out the signal and transmission parameters + signal_parameters, transmission_parameters = \ + signal_and_sink.parameters + + # Yield the keyword arguments for this connection + yield { + "source_port": source_port, + "signal_parameters": signal_parameters, + "transmission_parameters": transmission_parameters, + "sink_object": sink_object, + "sink_port": sink_port, + "reception_parameters": reception_parameters, + } + + +def _get_max_packets_received(port_kwargs): + """Get the maximum number of packets received by any single object in the + specified port kwargs. + """ + rx = collections.defaultdict(lambda: 0) + for kwargs in port_kwargs: + rx[kwargs["sink_object"]] += kwargs["signal_parameters"].weight + + return max(itervalues(rx)) + + +def _multiply_signals(in_kwargs, out_conn_kwargs): + """Multiply an input connection by a selection of outgoing connection and + yield keywords for every new connection. + + Yields + ------ + dict + Keyword arguments for `ConnectionMap.add_connection` + """ + # For every outgoing connection + for out_conn in out_conn_kwargs: + # Combine the transmission parameters + transmission_parameters, sink_port = _combine_transmission_params( + in_kwargs["transmission_parameters"], + out_conn["transmission_parameters"], + out_conn["sink_port"] + ) + + # If the connection has been optimised out then move on + if transmission_parameters is None and sink_port is None: + continue + + # Combine the reception parameters + reception_parameters = _combine_reception_params( + in_kwargs["reception_parameters"], + out_conn["reception_parameters"], + ) + + # Combine the signal parameters: the new signal will be latching if + # either the input or the output signals require it be so, it will have + # the weight assigned by the reception parameters. If either the input + # or output keyspace are None then the keyspace assigned to the other + # signal will be used; if neither are None then we break because + # there's no clear way to merge keyspaces. + in_sig_pars = in_kwargs["signal_parameters"] + out_sig_pars = out_conn["signal_parameters"] + + latching = in_sig_pars.latching or out_sig_pars.latching + weight = out_sig_pars.weight + + if in_sig_pars.keyspace is None or out_sig_pars.keyspace is None: + keyspace = in_sig_pars.keyspace or out_sig_pars.keyspace + else: + raise NotImplementedError("Cannot merge keyspaces") + + # Construct the new signal parameters + signal_parameters = SignalParameters(latching, weight, keyspace) + + # Yield the new keyword arguments + yield { + "source_object": in_kwargs["source_object"], + "source_port": in_kwargs["source_port"], + "signal_parameters": signal_parameters, + "transmission_parameters": transmission_parameters, + "sink_object": out_conn["sink_object"], + "sink_port": sink_port, + "reception_parameters": reception_parameters, + } + + +def _combine_transmission_params(in_transmission_parameters, + out_transmission_parameters, + final_port): + """Combine transmission parameters to join two signals into one, e.g., for + optimising out a passthrough Node. + + Returns + ------- + transmission_parameters + New transmission parameters + port + New receiving port for the connection + """ + assert isinstance(out_transmission_parameters, + PassthroughNodeTransmissionParameters) + + # Compute the new transform + new_transform = np.dot(out_transmission_parameters.transform, + in_transmission_parameters.transform) + + # If the resultant transform is empty then we return None to indicate that + # the connection should be dropped. + if np.all(new_transform == 0.0): + return None, None + + # If the connection is a global inhibition connection then truncate the + # transform and modify the final port to reroute the connection. + if (final_port is EnsembleInputPort.neurons and + np.all(new_transform[0] == new_transform[1:])): + # Truncate the transform + new_transform = new_transform[0] + new_transform.shape = (1, -1) # Ensure the result is a matrix + + # Change the final port + final_port = EnsembleInputPort.global_inhibition + + # Construct the new transmission parameters + if isinstance(in_transmission_parameters, + EnsembleTransmissionParameters): + transmission_params = EnsembleTransmissionParameters( + in_transmission_parameters.untransformed_decoders, + new_transform + ) + elif isinstance(in_transmission_parameters, + NodeTransmissionParameters): + transmission_params = NodeTransmissionParameters( + in_transmission_parameters.pre_slice, + in_transmission_parameters.function, + new_transform + ) + elif isinstance(in_transmission_parameters, + PassthroughNodeTransmissionParameters): + transmission_params = PassthroughNodeTransmissionParameters( + new_transform + ) + else: + raise NotImplementedError + + return transmission_params, final_port + + +def _combine_reception_params(in_reception_parameters, + out_reception_parameters): + """Combine reception parameters to join two signals into one, e.g., for + optimising out a passthrough Node. + """ + # Construct the new reception parameters + # Combine the filters + filter_in = in_reception_parameters.filter + filter_out = out_reception_parameters.filter + + if (filter_in is None or filter_out is None): + # If either filter is None then just use the filter from the other + # connection + new_filter = filter_in or filter_out + elif (isinstance(filter_in, nengo.LinearFilter) and + isinstance(filter_out, nengo.LinearFilter)): + # Both filters are linear filters, so multiply the numerators and + # denominators together to get a new linear filter. + new_num = np.polymul(filter_in.num, filter_out.num) + new_den = np.polymul(filter_in.den, filter_out.den) + + new_filter = nengo.LinearFilter(new_num, new_den) + else: + raise NotImplementedError + + # Take the size in from the second reception parameter, construct the new + # reception parameters. + return ReceptionParameters(new_filter, out_reception_parameters.width) diff --git a/regression-tests/test_global_inhibition.py b/regression-tests/test_global_inhibition.py index 4097d50..10f0b50 100644 --- a/regression-tests/test_global_inhibition.py +++ b/regression-tests/test_global_inhibition.py @@ -57,7 +57,55 @@ def test_global_inhibition(source): np.all(-0.05 <= data[index20:, 1])) +def test_global_inhibition_from_optimised_out_passthrough_node(): + with nengo.Network("Test Network") as network: + # Create a 2-d ensemble representing a constant value + input_value = nengo.Node([0.25, -0.3]) + ens = nengo.Ensemble(200, 2) + nengo.Connection(input_value, ens) + p_ens = nengo.Probe(ens, synapse=0.05) + + # The gate should be open initially and then closed after 1s + gate_control = nengo.Node(lambda t: 0.0 if t < 1.0 else 1.0) + + # Add the passthrough Node + gate_ptn = nengo.Node(size_in=ens.n_neurons) + nengo.Connection(gate_ptn, ens.neurons) + + # Create the control and + nengo.Connection(gate_control, gate_ptn, + transform=[[-10.0]] * ens.n_neurons) + + # Mark appropriate Nodes as functions of time + nengo_spinnaker.add_spinnaker_params(network.config) + network.config[gate_control].function_of_time = True + + # Create the simulate and simulate + sim = nengo_spinnaker.Simulator(network) + + # Run the simulation for long enough to ensure that the decoded value is + # with +/-20% of the input value. + with sim: + sim.run(2.0) + + # Check that the values are decoded as expected + index10 = int(p_ens.synapse.tau * 4 / sim.dt) + index11 = 1.0 / sim.dt + index20 = index11 + int(p_ens.synapse.tau * 4 / sim.dt) + data = sim.data[p_ens] + + assert (np.all(+0.20 <= data[index10:index11, 0]) and + np.all(+0.30 >= data[index10:index11, 0]) and + np.all(+0.05 >= data[index20:, 0]) and + np.all(-0.05 <= data[index20:, 0])) + assert (np.all(-0.36 <= data[index10:index11, 1]) and + np.all(-0.24 >= data[index10:index11, 1]) and + np.all(+0.05 >= data[index20:, 1]) and + np.all(-0.05 <= data[index20:, 1])) + + if __name__ == "__main__": test_global_inhibition("ensemble") test_global_inhibition("node") test_global_inhibition("f_of_t_node") + test_global_inhibition_from_optimised_out_passthrough_node() diff --git a/tests/builder/test_ensemble.py b/tests/builder/test_ensemble.py index 3b7282a..625a99d 100644 --- a/tests/builder/test_ensemble.py +++ b/tests/builder/test_ensemble.py @@ -248,33 +248,26 @@ def test_global_inhibition_sink(self): assert sink.target.obj is b_ens assert sink.target.port is ensemble.EnsembleInputPort.global_inhibition - def test_arbitrary_neuron_sink(self): - """We have no plan to support arbitrary connections to neurons.""" - with nengo.Network(): - a = nengo.Ensemble(100, 2) - b = nengo.Ensemble(200, 4) - - a_b = nengo.Connection(a, b.neurons, - transform=[[1.0, 0.5]]*199 + [[0.5, 1.0]]) - - # Create a model with the Ensemble for b in it - model = builder.Model() - b_ens = operators.EnsembleLIF(b) - model.object_operators[b] = b_ens - - # This should fail - with pytest.raises(NotImplementedError): - ensemble.get_neurons_sink(model, a_b) - - def test_neuron_sink(self): + @pytest.mark.parametrize("source", ("neurons", "value")) + def test_arbitrary_neuron_sink(self, source): """Test that standard connections to neurons return an appropriate sink. + + We have no plan to support arbitrary connections to neurons, but we + allow them at this stage because they may later become global + inhibition connections when we optimise out passthrough Nodes. """ with nengo.Network(): a = nengo.Ensemble(100, 2) b = nengo.Ensemble(100, 4) - a_b = nengo.Connection(a.neurons, b.neurons, transform=np.eye(100)) + if source == "neurons": + a_b = nengo.Connection(a.neurons, b.neurons, + transform=np.eye(100)) + else: + a_b = nengo.Connection(a, b.neurons, + transform=[[1.0, 0.5]]*99 + + [[0.5, 1.0]]) # Create a model with the Ensemble for b in it model = builder.Model() diff --git a/tests/builder/test_model.py b/tests/builder/test_model.py index 08d99f6..b3c2f46 100644 --- a/tests/builder/test_model.py +++ b/tests/builder/test_model.py @@ -407,6 +407,7 @@ class G(object): cm.add_connection(c, None, model.SignalParameters(), None, e, None, None) cm.add_connection(d, None, model.SignalParameters(), None, e, None, None) cm.add_connection(e, None, model.SignalParameters(), None, f, None, None) + cm._connections[f][None] = list() # Remove the sinkless filters removed = model.remove_sinkless_objects(cm, mock.Mock) diff --git a/tests/test_config.py b/tests/test_config.py index e446a5d..85e7c44 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -12,7 +12,7 @@ def test_add_spinnaker_params(): # Create a network with nengo.Network() as net: n_ft = nengo.Node(lambda t: [t, t**2]) - n_pt = nengo.Node(size_in=100) + ptn = nengo.Node(size_in=2) # Setting SpiNNaker-specific options before calling `add_spinnaker_params` # should fail. @@ -27,10 +27,11 @@ def test_add_spinnaker_params(): for param, value in [ ("n_cores_per_chip", 16), - ("n_chips", 4) + ("n_chips", 4), + ("optimize_out", False), ]: with pytest.raises(AttributeError) as excinfo: - setattr(net.config[n_ft], param, value) + setattr(net.config[ptn], param, value) for param, value in [ ("placer", lambda r, n, m, c: None), @@ -51,9 +52,10 @@ def test_add_spinnaker_params(): assert net.config[nengo.Node].function_of_time is False assert net.config[nengo.Node].function_of_time_period is None + assert net.config[nengo.Node].optimize_out is None - assert net.config[nengo.Node].n_cores_per_chip == None - assert net.config[nengo.Node].n_chips == None + assert net.config[nengo.Node].n_cores_per_chip is None + assert net.config[nengo.Node].n_chips is None assert net.config[Simulator].placer is par.place assert net.config[Simulator].placer_kwargs == {} diff --git a/tests/utils/test_utils_model.py b/tests/utils/test_utils_model.py new file mode 100644 index 0000000..3a4b742 --- /dev/null +++ b/tests/utils/test_utils_model.py @@ -0,0 +1,639 @@ +import mock +import nengo +import numpy as np +import pytest + +from nengo_spinnaker.utils import model as model_utils +from nengo_spinnaker.builder import model + +from nengo_spinnaker.builder.ensemble import EnsembleTransmissionParameters +from nengo_spinnaker.builder.node import ( + PassthroughNodeTransmissionParameters, NodeTransmissionParameters +) +from nengo_spinnaker.builder.ports import EnsembleInputPort + + +def test_get_force_removal_passnodes(): + """Passthrough Nodes in networks of passthrough Nodes that connect to + neurons need to be marked for removal. + """ + # Construct a network with three sets of passthrough Nodes + with nengo.Network() as model: + # Spurious Node that should be ignored + nengo.Node(lambda t: t) + + # First set of passthrough Nodes + # 0 --\ + # 1 --> 3 ---> 4 --> Neurons + # 2 --/ \--> 5 --> Ensemble + in_a = nengo.Node(np.zeros(100)) + out_a = nengo.Ensemble(100, 1) + + set_a = list(nengo.Node(size_in=100, label="A{}".format(n)) + for n in range(6)) + + for node in set_a[:3]: + nengo.Connection(in_a, node) + nengo.Connection(node, set_a[3]) + + for node in set_a[4:]: + nengo.Connection(set_a[3], node) + + nengo.Connection(set_a[4], out_a.neurons) + nengo.Connection(set_a[5], out_a, transform=np.zeros((1, 100))) + + # Second set of passthrough Nodes + # Neurons -> 0 --\ + # Value ---> 1 --> 2 --> Value + with nengo.Network(): + in_b = nengo.Ensemble(200, 1) + set_b = list(nengo.Node(size_in=1, label="B{}".format(n)) + for n in range(3)) + + nengo.Connection(in_b.neurons, set_b[0], + transform=np.ones((1, 200))) + nengo.Connection(in_b, set_b[1]) + + nengo.Connection(set_b[0], set_b[2]) + nengo.Connection(set_b[1], set_b[2]) + + nengo.Connection(set_b[2], in_b) + + # Third set of passthrough Nodes + # Ensemble -> 0 -> 1 -> Ensemble + set_c = list(nengo.Node(size_in=2, label="C{}".format(n)) + for n in range(2)) + in_c = nengo.Ensemble(100, 2) + + nengo.Connection(in_c, set_c[0]) + nengo.Connection(set_c[0], set_c[1]) + nengo.Connection(set_c[1], in_c) + + # Get whether the passthrough Nodes should be marked for removal + assert model_utils.get_force_removal_passnodes(model) == set(set_a + set_b) + + +def test_remove_operator_from_connection_map(): + """Test that operators are correctly removed from connection maps. + + We test the following model: + + O1 ------> O3 -----> O5 + /-----/ \ + / \----> O6 + O2 --> O4 + + Removing `O3' should result in: + + O1 --> O6 + / + /-> 05 + O2 --> O4 + """ + # Construct the operators + operators = [mock.Mock(name="O{}".format(i + 1)) for i in range(6)] + + # Create a connection map + cm = model.ConnectionMap() + + # Add the connection O1 to O3 + sps = model.SignalParameters(True, 6) + tps = PassthroughNodeTransmissionParameters(np.vstack([np.eye(3), + np.zeros((3, 3))])) + rps = model.ReceptionParameters(None, 6) + + cm.add_connection(source_object=operators[0], + source_port=None, + signal_parameters=sps, + transmission_parameters=tps, + sink_object=operators[2], + sink_port=None, + reception_parameters=rps) + + # Add the connection O2 to O3 + sps = model.SignalParameters(False, 6) + tps = PassthroughNodeTransmissionParameters(np.vstack([np.zeros((3, 3)), + np.eye(3)])) + rps = model.ReceptionParameters(None, 6) + + cm.add_connection(source_object=operators[1], + source_port=None, + signal_parameters=sps, + transmission_parameters=tps, + sink_object=operators[2], + sink_port=None, + reception_parameters=rps) + + # Add the connection O2 to O4 (with a custom keyspace) + sps = model.SignalParameters(False, 6, mock.Mock("Keyspace 1")) + tps = PassthroughNodeTransmissionParameters(np.vstack([np.eye(3), + np.eye(3)])) + rps = model.ReceptionParameters(None, 6) + + cm.add_connection(source_object=operators[1], + source_port=None, + signal_parameters=sps, + transmission_parameters=tps, + sink_object=operators[3], + sink_port=None, + reception_parameters=rps) + + # Add the connection O3 to O5 + sps = model.SignalParameters(False, 3) + tps = PassthroughNodeTransmissionParameters(np.hstack((np.zeros((3, 3)), + np.eye(3)))) + rps = model.ReceptionParameters(None, 3) + + cm.add_connection(source_object=operators[2], + source_port=None, + signal_parameters=sps, + transmission_parameters=tps, + sink_object=operators[4], + sink_port=None, + reception_parameters=rps) + + # Add the connection O3 to O6 + sps = model.SignalParameters(False, 3) + tps = PassthroughNodeTransmissionParameters(np.hstack([np.eye(3), + np.eye(3)]) * 2) + rps = model.ReceptionParameters(None, 3) + + cm.add_connection(source_object=operators[2], + source_port=None, + signal_parameters=sps, + transmission_parameters=tps, + sink_object=operators[5], + sink_port=None, + reception_parameters=rps) + + # Remove O3 from the connection map + model_utils.remove_operator_from_connection_map(cm, operators[2]) + + # Check that the received and transmitted signals are as expected + # FROM O1 + from_o1 = cm._connections[operators[0]] + assert len(from_o1) == 1 + assert len(from_o1[None]) == 1 + + ((signal_parameters, transmission_parameters), sinks) = from_o1[None][0] + assert signal_parameters == model.SignalParameters(True, 3, None) + assert transmission_parameters.transform.shape == (3, 3) + assert np.all(transmission_parameters.transform == np.eye(3)*2) + assert sinks == [(operators[5], None, model.ReceptionParameters(None, 3))] + + # FROM O2 + from_o2 = cm._connections[operators[1]] + assert len(from_o2) == 1 + assert len(from_o2[None]) == 3 + + for ((signal_parameters, transmission_parameters), sinks) in from_o2[None]: + if transmission_parameters.transform.shape == (3, 3): + assert (signal_parameters == + model.SignalParameters(False, 3, None)) + + if np.any(transmission_parameters.transform == 2.0): + # TO O6 + assert np.all(transmission_parameters.transform == + np.eye(3) * 2) + assert sinks == [(operators[5], None, + model.ReceptionParameters(None, 3))] + else: + # TO O5 + assert np.all(transmission_parameters.transform == + np.eye(3)) + assert sinks == [(operators[4], None, + model.ReceptionParameters(None, 3))] + else: + # TO O4 + assert transmission_parameters.transform.shape == (6, 3) + assert np.all(transmission_parameters.transform == + np.vstack([np.eye(3)]*2)) + assert sinks == [(operators[3], None, + model.ReceptionParameters(None, 6))] + + # We now add a connection from O4 to O6 with a custom keyspace. Removing + # O4 will fail because keyspaces can't be merged. + signal_params = model.SignalParameters(False, 1, mock.Mock("Keyspace 2")) + transmission_params = PassthroughNodeTransmissionParameters(1.0) + reception_params = model.ReceptionParameters(None, 1) + + cm.add_connection( + source_object=operators[3], + source_port=None, + signal_parameters=signal_params, + transmission_parameters=transmission_params, + sink_object=operators[5], + sink_port=None, + reception_parameters=reception_params + ) + + with pytest.raises(NotImplementedError) as err: + model_utils.remove_operator_from_connection_map( + cm, operators[3] + ) + assert "keyspace" in str(err.value).lower() + + +def test_remove_operator_from_connection_map_unforced(): + """Check that a calculation is made to determine whether it is better to + keep or remove an operator depending on the density of the outgoing + connections. In the example: + + /- G[0] + A[0] --\ /-- D[0] --\ /-- G[1] + A[1] --- B --- C --- D[1] --- E --- F --- G[2] + A[n] --/ \-- D[n] --/ \-- G[3] + \- G[n] + + B, C and F should be removed but E should be retained: + + /- G[0] + A[0] --- D[0] --\ /-- G[1] + A[1] --- D[1] --- E --- G[2] + A[n] --- D[n] --/ \-- G[3] + \- G[n] + """ + # Create the operators + D = 512 + SD = 16 + + op_A = [mock.Mock(name="A{}".format(i)) for i in range(D//SD)] + op_B = mock.Mock(name="B") + op_C = mock.Mock(name="C") + op_D = [mock.Mock(name="D{}".format(i)) for i in range(D//SD)] + op_E = mock.Mock(name="E") + op_F = mock.Mock(name="F") + op_G = [mock.Mock(name="G{}".format(i)) for i in range(D)] + + # Create a connection map + cm = model.ConnectionMap() + + # Create the fan-in connections + for sources, sink in ((op_A, op_B), (op_D, op_E)): + # Get the signal and reception parameters + sps = model.SignalParameters(True, D) + rps = model.ReceptionParameters(None, D) + + for i, source in enumerate(sources): + # Get the transform + transform = np.zeros((D, SD)) + transform[i*SD:(i+1)*SD, :] = np.eye(SD) + + # Get the parameters + tps = EnsembleTransmissionParameters(np.ones((1, SD)), transform) + + cm.add_connection(source_object=source, source_port=None, + signal_parameters=sps, + transmission_parameters=tps, + sink_object=sink, sink_port=None, + reception_parameters=rps) + + # Create the fan-out connection C to D[...] + # Get the signal and reception parameters + sps = model.SignalParameters(True, SD) + rps = model.ReceptionParameters(None, SD) + + for i, sink in enumerate(op_D): + # Get the transform + transform = np.zeros((SD, D)) + transform[:, i*SD:(i+1)*SD] = np.eye(SD) + + # Get the parameters + tps = PassthroughNodeTransmissionParameters(transform) + + cm.add_connection(source_object=op_C, source_port=None, + signal_parameters=sps, + transmission_parameters=tps, + sink_object=sink, sink_port=None, + reception_parameters=rps) + + # Create the connection B to C + sps = model.SignalParameters(True, D) + rps = model.ReceptionParameters(None, D) + tps = PassthroughNodeTransmissionParameters(np.eye(D)) + + cm.add_connection(source_object=op_B, source_port=None, + signal_parameters=sps, + transmission_parameters=tps, + sink_object=op_C, sink_port=None, + reception_parameters=rps) + + # Create the connection E to F + transform = np.zeros((D, D)) + for i in range(D): + for j in range(D): + transform[i, j] = i + j + + sps = model.SignalParameters(True, D) + rps = model.ReceptionParameters(None, D) + tps = PassthroughNodeTransmissionParameters(transform) + + cm.add_connection(source_object=op_E, source_port=None, + signal_parameters=sps, + transmission_parameters=tps, + sink_object=op_F, sink_port=None, + reception_parameters=rps) + + # Create the fan-out connections from F + sps = model.SignalParameters(True, SD) + rps = model.ReceptionParameters(None, SD) + + for i, sink in enumerate(op_G): + # Get the transform + transform = np.zeros((1, D)) + transform[:, i] = 1.0 + + # Get the parameters + tps = PassthroughNodeTransmissionParameters(transform) + + cm.add_connection(source_object=op_F, source_port=None, + signal_parameters=sps, + transmission_parameters=tps, + sink_object=sink, sink_port=None, + reception_parameters=rps) + + # Remove all of the passthrough Nodes, only E should be retained + assert model_utils.remove_operator_from_connection_map(cm, op_B, + force=False) + assert model_utils.remove_operator_from_connection_map(cm, op_C, + force=False) + assert not model_utils.remove_operator_from_connection_map(cm, op_E, + force=False) + assert model_utils.remove_operator_from_connection_map(cm, op_F, + force=False) + + # Check that each A has only one outgoing signal and that it terminates at + # the paired D. Additionally check that each D has only one outgoing + # signal and that it terminates at E. + for a, d in zip(op_A, op_D): + # Connections from A[n] + from_a = cm._connections[a] + assert len(from_a) == 1 + assert len(from_a[None]) == 1 + + ((signal_parameters, transmission_parameters), sinks) = from_a[None][0] + assert signal_parameters == model.SignalParameters(True, SD, None) + assert transmission_parameters.transform.shape == (SD, SD) + assert np.all(transmission_parameters.transform == np.eye(SD)) + assert sinks == [(d, None, model.ReceptionParameters(None, SD))] + + # Connection(s) from D[n] + from_d = cm._connections[d] + assert len(from_d) == 1 + assert len(from_d[None]) == 1 + + ((signal_parameters, transmission_parameters), sinks) = from_d[None][0] + assert signal_parameters == model.SignalParameters(True, D, None) + assert transmission_parameters.transform.shape == (D, SD) + + # Check that there are many connections from E + from_e = cm._connections[op_E] + assert len(from_e) == 1 + print(from_e[None][0].parameters[1].transform) + assert len(from_e[None]) == D + + +class TestCombineTransmissionAndReceptionParameters(object): + """Test the correct combination of transmission parameters.""" + # Ensemble and Passthrough Node to StandardInput or Neurons - NOT global + # inhibition + @pytest.mark.parametrize("final_port", (model.InputPort.standard, + EnsembleInputPort.neurons)) + def test_ens_to_x(self, final_port): + # Create the ingoing connection parameters + in_transmission_params = EnsembleTransmissionParameters( + np.random.uniform(size=(100, 10)), 1.0 + ) + + # Create the outgoing connection parameters + out_transmission_params = PassthroughNodeTransmissionParameters( + np.hstack([np.eye(5), np.zeros((5, 5))]) + ) + + # Combine the parameter sets + new_tps, new_in_port = model_utils._combine_transmission_params( + in_transmission_params, + out_transmission_params, + final_port + ) + + # Check that all the parameters are correct + assert np.all(new_tps.untransformed_decoders == + in_transmission_params.untransformed_decoders) + assert np.all(new_tps.transform == + out_transmission_params.transform) + assert new_tps.decoders.shape == (5, 100) + assert new_in_port is final_port + + # Node and Passthrough Node to Standard Input or Neurons - NOT global + # inhibition + @pytest.mark.parametrize("passthrough", (False, True)) + @pytest.mark.parametrize("final_port", (model.InputPort.standard, + EnsembleInputPort.neurons)) + def test_node_to_x(self, passthrough, final_port): + # Create the ingoing connection parameters + if not passthrough: + in_transmission_params = NodeTransmissionParameters( + slice(10, 15), + mock.Mock(), + np.random.uniform(size=(10, 5)), + ) + else: + in_transmission_params = PassthroughNodeTransmissionParameters( + np.random.uniform(size=(10, 5)), + ) + + # Create the outgoing connection parameters + out_transmission_params = PassthroughNodeTransmissionParameters( + np.hstack((np.zeros((5, 5)), np.eye(5))) + ) + + # Combine the parameter sets + new_tps, new_in_port = model_utils._combine_transmission_params( + in_transmission_params, + out_transmission_params, + final_port + ) + + # Check that all the parameters are correct + if not passthrough: + assert new_tps.pre_slice == in_transmission_params.pre_slice + assert new_tps.function is in_transmission_params.function + + assert np.all(new_tps.transform == np.dot( + out_transmission_params.transform, + in_transmission_params.transform + )) + assert new_tps.transform.shape == (5, 5) + assert new_in_port is final_port + + # Ensemble and Passthrough Node to Neurons (Global Inhibition) + def test_ens_to_gi(self): + # Create the ingoing connection parameters + in_transmission_params = EnsembleTransmissionParameters( + np.random.uniform(size=(100, 7)), 1.0 + ) + + # Create the outgoing connection parameters + out_transmission_params = PassthroughNodeTransmissionParameters( + np.ones((200, 7)) + ) + + # Combine the parameter sets + new_tps, new_in_port = model_utils._combine_transmission_params( + in_transmission_params, + out_transmission_params, + EnsembleInputPort.neurons + ) + + # Check that all the parameters are correct + assert np.all(new_tps.transform == 1.0) + assert new_tps.transform.shape == (1, 7) + assert new_tps.decoders.shape == (1, 100) + assert new_in_port is EnsembleInputPort.global_inhibition + + # Node and Passthrough Node to Neurons (Global Inhibition) + @pytest.mark.parametrize("passthrough", (False, True)) + def test_node_to_gi(self, passthrough): + # Create the ingoing connection parameters + if not passthrough: + in_transmission_params = NodeTransmissionParameters( + slice(10, 20), + mock.Mock(), + np.ones((100, 1)) + ) + else: + in_transmission_params = PassthroughNodeTransmissionParameters( + np.ones((100, 1)) + ) + + # Create the outgoing connection parameters + out_transmission_params = PassthroughNodeTransmissionParameters( + np.eye(100) + ) + + # Combine the parameter sets + new_tps, new_in_port = model_utils._combine_transmission_params( + in_transmission_params, out_transmission_params, + EnsembleInputPort.neurons + ) + + # Check that all the parameters are correct + if not passthrough: + assert new_tps.pre_slice == in_transmission_params.pre_slice + assert new_tps.function is in_transmission_params.function + + assert np.all(new_tps.transform == np.dot( + out_transmission_params.transform, + in_transmission_params.transform + )[0]) + assert new_tps.transform.shape[0] == 1 + assert new_in_port is EnsembleInputPort.global_inhibition + + @pytest.mark.parametrize("from_type", ("ensemble", "node", "ptn")) + @pytest.mark.parametrize("final_port", + (model.InputPort.standard, + EnsembleInputPort.neurons, + EnsembleInputPort.global_inhibition)) + def test_x_to_x_optimize_out(self, from_type, final_port): + # Create the ingoing connection parameters + in_transmission_params = { + "ensemble": EnsembleTransmissionParameters( + np.random.uniform(size=(100, 1)), + np.array([[1.0], [0.0]]) + ), + "node": NodeTransmissionParameters( + slice(10, 20), + mock.Mock(), + np.array([[1.0], [0.0]]) + ), + "ptn": PassthroughNodeTransmissionParameters( + np.array([[1.0], [0.0]]) + ), + }[from_type] + + # Create the outgoing connection parameters + out_transmission_params = PassthroughNodeTransmissionParameters( + np.array([[0.0, 0.0], [0.0, 1.0]]) + ) + + # Combine the parameter sets + new_tps, new_in_port = model_utils._combine_transmission_params( + in_transmission_params, + out_transmission_params, + final_port + ) + + # Check that the connection is optimised out + assert new_tps is None + assert new_in_port is None + + def test_unknown_to_x(self): + # Create the ingoing connection parameters + in_transmission_params = mock.Mock() + in_transmission_params.transform = 1.0 + + # Create the outgoing connection parameters + out_transmission_params = PassthroughNodeTransmissionParameters( + np.array([[0.0, 0.0], [0.0, 1.0]]) + ) + + # Combine the parameter sets + with pytest.raises(NotImplementedError): + model_utils._combine_transmission_params( + in_transmission_params, + out_transmission_params, + None + ) + + # Test combining reception parameters + def test_combine_none_and_lowpass_filter(self): + # Create the ingoing reception parameters + reception_params_a = model.ReceptionParameters(nengo.Lowpass(0.05), 1) + + # Create the outgoing reception parameters + reception_params_b = model.ReceptionParameters(None, 3) + + # Combine the parameter each way round + for a, b in ((reception_params_a, reception_params_b), + (reception_params_a, reception_params_b)): + new_rps = model_utils._combine_reception_params(a, b) + + # Check filter type + assert new_rps.filter == reception_params_a.filter + + # Check width is the width of the receiving item + assert new_rps.width == b.width + + def test_combine_linear_and_linear_filter(self): + # Create the ingoing reception parameters + reception_params_a = model.ReceptionParameters(nengo.Lowpass(0.05), 1) + + # Create the outgoing reception parameters + reception_params_b = model.ReceptionParameters(nengo.Lowpass(0.01), 5) + + # Combine the parameter each way round + for a, b in ((reception_params_a, reception_params_b), + (reception_params_a, reception_params_b)): + new_rps = model_utils._combine_reception_params(a, b) + + # Check filter type + synapse = new_rps.filter + assert synapse.num == [1] + assert np.all(synapse.den == [0.05 * 0.01, 0.05 + 0.01, 1]) + + # Check width is the width of the receiving item + assert new_rps.width == b.width + + def test_combine_unknown_filter(self): + # Create the ingoing reception parameters + reception_params_a = model.ReceptionParameters(nengo.Lowpass(0.05), 1) + + # Create the outgoing reception parameters + reception_params_b = model.ReceptionParameters(mock.Mock(), 1) + + # Combine the parameter each way round + for a, b in ((reception_params_a, reception_params_b), + (reception_params_a, reception_params_b)): + with pytest.raises(NotImplementedError): + new_rps = model_utils._combine_reception_params(a, b)