diff --git a/nengo_spinnaker/builder/builder.py b/nengo_spinnaker/builder/builder.py index 54386d0..67b3c23 100644 --- a/nengo_spinnaker/builder/builder.py +++ b/nengo_spinnaker/builder/builder.py @@ -344,13 +344,43 @@ def make_netlist(self, *args, **kwargs): # Construct nets from the signals nets = list() - for signal in self.connection_map.get_signals(): + for signal, transmission_parameters in \ + self.connection_map.get_signals(): # Get the source and sink vertices - sources = operator_vertices[signal.source] + original_sources = operator_vertices[signal.source] + if not isinstance(original_sources, collections.Iterable): + original_sources = (original_sources, ) + + # Filter out any sources which have an `accepts_signal` method and + # return False when this is called with the signal and transmission + # parameters. + sources = list() + for source in original_sources: + # For each source which either doesn't have a + # `transmits_signal` method or returns True when this is called + # with the signal and transmission parameters add a new net to + # the netlist. + if (hasattr(source, "transmits_signal") and not + source.transmits_signal(signal, + transmission_parameters)): + pass # This source is ignored + else: + # Add the source to the final list of sources + sources.append(source) sinks = collections_ext.flatinsertionlist() for sink in signal.sinks: - sinks.append(operator_vertices[sink]) + # Get all the sink vertices + sink_vertices = operator_vertices[sink] + if not isinstance(sink_vertices, collections.Iterable): + sink_vertices = (sink_vertices, ) + + # Include any sinks which either don't have an `accepts_signal` + # method or return true when this is called with the signal and + # transmission parameters. + sinks.append(s for s in sink_vertices if + not hasattr(s, "accepts_signal") or + s.accepts_signal(signal, transmission_parameters)) # Create the net(s) nets.append(NMNet(sources, list(sinks), diff --git a/nengo_spinnaker/builder/connection.py b/nengo_spinnaker/builder/connection.py index 6b088dd..a2691ef 100644 --- a/nengo_spinnaker/builder/connection.py +++ b/nengo_spinnaker/builder/connection.py @@ -1,4 +1,6 @@ import nengo +import numpy as np + from .builder import Model, ObjectPort, spec from .model import ReceptionParameters, InputPort, OutputPort @@ -23,3 +25,95 @@ def build_generic_reception_params(model, conn): """ # Just extract the synapse from the connection. return ReceptionParameters(conn.synapse, conn.post_obj.size_in) + + +class EnsembleTransmissionParameters(object): + """Transmission parameters for a connection originating at an Ensemble. + + Attributes + ---------- + decoders : array + Decoders to use for the connection. + """ + def __init__(self, decoders, transform): + # Copy the decoders + self.untransformed_decoders = np.array(decoders) + self.transform = np.array(transform) + + # Compute and store the transformed decoders + self.decoders = np.dot(transform, decoders.T).T + + # Make the arrays read-only + self.untransformed_decoders.flags['WRITEABLE'] = False + self.transform.flags['WRITEABLE'] = False + self.decoders.flags['WRITEABLE'] = False + + def __ne__(self, other): + return not (self == other) + + def __eq__(self, other): + # Equal iff. the objects are of the same type + if type(self) is not type(other): + return False + + # Equal iff. the decoders are the same shape + if self.decoders.shape != other.decoders.shape: + return False + + # Equal iff. the decoder values are the same + if np.any(self.decoders != other.decoders): + return False + + return True + + +class PassthroughNodeTransmissionParameters(object): + """Parameters describing connections which originate from pass through + Nodes. + """ + def __init__(self, transform): + # Store the parameters, copying the transform + self.transform = np.array(transform) + + def __ne__(self, other): + return not (self == other) + + def __eq__(self, other): + # Equivalent if the same type + if type(self) is not type(other): + return False + + # and the transforms are equivalent + if (self.transform.shape != other.transform.shape or + np.any(self.transform != other.transform)): + return False + + return True + + +class NodeTransmissionParameters(PassthroughNodeTransmissionParameters): + """Parameters describing connections which originate from Nodes.""" + def __init__(self, pre_slice, function, transform): + # Store the parameters + super(NodeTransmissionParameters, self).__init__(transform) + self.pre_slice = pre_slice + self.function = function + + def __hash__(self): + # Hash by ID + return hash(id(self)) + + def __eq__(self, other): + # Parent equivalence + if not super(NodeTransmissionParameters, self).__eq__(other): + return False + + # Equivalent if the pre_slices are exactly the same + if self.pre_slice != other.pre_slice: + return False + + # Equivalent if the functions are the same + if self.function is not other.function: + return False + + return True diff --git a/nengo_spinnaker/builder/ensemble.py b/nengo_spinnaker/builder/ensemble.py index fbb2988..2edfe46 100644 --- a/nengo_spinnaker/builder/ensemble.py +++ b/nengo_spinnaker/builder/ensemble.py @@ -9,6 +9,7 @@ import numpy as np from .builder import BuiltConnection, Model, ObjectPort, spec +from .connection import EnsembleTransmissionParameters from .model import InputPort from .ports import EnsembleInputPort from .. import operators @@ -137,37 +138,6 @@ def build_lif(model, ens): model.object_operators[ens] = operators.EnsembleLIF(ens) -class EnsembleTransmissionParameters(object): - """Transmission parameters for a connection originating at an Ensemble. - - Attributes - ---------- - decoders : array - Decoders to use for the connection (including the transform). - """ - def __init__(self, decoders): - # Copy the decoders - self.decoders = np.array(decoders) - - def __ne__(self, other): - return not (self == other) - - def __eq__(self, other): - # Equal iff. the objects are of the same type - if type(self) is not type(other): - return False - - # Equal iff. the decoders are the same shape - if self.decoders.shape != other.decoders.shape: - return False - - # Equal iff. the decoder values are the same - if np.any(self.decoders != other.decoders): - return False - - return True - - @Model.transmission_parameter_builders.register(nengo.Ensemble) def build_from_ensemble_connection(model, conn): """Build the parameters object for a connection from an Ensemble.""" @@ -198,10 +168,7 @@ def build_from_ensemble_connection(model, conn): np.all(transform[0, :] == transform[1:, :])): transform = np.array([transform[0]]) - # Multiply the decoders by the transform and return this as the - # transmission parameters. - full_decoders = np.dot(transform, decoders.T).T - return EnsembleTransmissionParameters(full_decoders) + return EnsembleTransmissionParameters(decoders, transform) @Model.transmission_parameter_builders.register(nengo.ensemble.Neurons) diff --git a/nengo_spinnaker/builder/model.py b/nengo_spinnaker/builder/model.py index 9e032b8..766bbde 100644 --- a/nengo_spinnaker/builder/model.py +++ b/nengo_spinnaker/builder/model.py @@ -156,14 +156,14 @@ def get_signals(self): # For each source object and set of sinks yield a new signal for source, port_conns in iteritems(self._connections): # For each connection look at the sinks and the signal parameters - for (sig_pars, _), par_sinks in chain(*itervalues(port_conns)): + for (sig_pars, transmission_pars), par_sinks in \ + chain(*itervalues(port_conns)): # Create a signal using these parameters - yield Signal( - source, - (ps.sink_object for ps in par_sinks), # Extract the sinks - sig_pars.keyspace, - sig_pars.weight - ) + yield (Signal(source, + (ps.sink_object for ps in par_sinks), # Sinks + sig_pars.keyspace, + sig_pars.weight), + transmission_pars) class OutputPort(enum.Enum): diff --git a/nengo_spinnaker/builder/node.py b/nengo_spinnaker/builder/node.py index c4808d2..9530f9c 100644 --- a/nengo_spinnaker/builder/node.py +++ b/nengo_spinnaker/builder/node.py @@ -6,6 +6,8 @@ from nengo_spinnaker.builder.builder import ObjectPort, spec, Model from nengo_spinnaker.builder.model import InputPort, OutputPort +from .connection import (PassthroughNodeTransmissionParameters, + NodeTransmissionParameters) from nengo_spinnaker.operators import Filter, ValueSink, ValueSource from nengo_spinnaker.utils.config import getconfig @@ -294,58 +296,6 @@ def build_node_transmission_parameters(model, conn): return PassthroughNodeTransmissionParameters(transform) -class PassthroughNodeTransmissionParameters(object): - """Parameters describing connections which originate from pass through - Nodes. - """ - def __init__(self, transform): - # Store the parameters, copying the transform - self.transform = np.array(transform) - - def __ne__(self, other): - return not (self == other) - - def __eq__(self, other): - # Equivalent if the same type - if type(self) is not type(other): - return False - - # and the transforms are equivalent - if (self.transform.shape != other.transform.shape or - np.any(self.transform != other.transform)): - return False - - return True - - -class NodeTransmissionParameters(PassthroughNodeTransmissionParameters): - """Parameters describing connections which originate from Nodes.""" - def __init__(self, pre_slice, function, transform): - # Store the parameters - super(NodeTransmissionParameters, self).__init__(transform) - self.pre_slice = pre_slice - self.function = function - - def __hash__(self): - # Hash by ID - return hash(id(self)) - - def __eq__(self, other): - # Parent equivalence - if not super(NodeTransmissionParameters, self).__eq__(other): - return False - - # Equivalent if the pre_slices are exactly the same - if self.pre_slice != other.pre_slice: - return False - - # Equivalent if the functions are the same - if self.function is not other.function: - return False - - return True - - class InputNode(nengo.Node): """Node which queries the IO controller for the input to a Node from.""" def __init__(self, node, controller): diff --git a/tests/builder/test_builder.py b/tests/builder/test_builder.py index e0dfedc..888db84 100644 --- a/tests/builder/test_builder.py +++ b/tests/builder/test_builder.py @@ -594,6 +594,14 @@ def test_multiple_sink_vertices(self): operator_b.make_vertices.return_value = \ netlistspec([vertex_b0, vertex_b1], load_fn_b) + # Create a third operator, which won't accept the signal + vertex_c = mock.Mock(name="vertex C") + vertex_c.accepts_signal.side_effect = lambda _, __: False + + object_c = mock.Mock(name="object C") + operator_c = mock.Mock(name="operator C") + operator_c.make_vertices.return_value = netlistspec(vertex_c) + # Create a signal between the operators keyspace = mock.Mock(name="keyspace") keyspace.length = 32 @@ -603,14 +611,24 @@ def test_multiple_sink_vertices(self): model = Model() model.object_operators[object_a] = operator_a model.object_operators[object_b] = operator_b + model.object_operators[object_c] = operator_c model.connection_map.add_connection( operator_a, None, signal_ab_parameters, None, operator_b, None, None ) + model.connection_map.add_connection( + operator_a, None, signal_ab_parameters, None, + operator_c, None, None + ) netlist = model.make_netlist() + # Check that the "accepts_signal" method of vertex_c was called with + # reasonable arguments + assert vertex_c.accepts_signal.called + # Check that the netlist is as expected - assert set(netlist.vertices) == set([vertex_a, vertex_b0, vertex_b1]) + assert set(netlist.vertices) == set( + [vertex_a, vertex_b0, vertex_b1, vertex_c]) assert len(netlist.nets) == 1 for net in netlist.nets: assert net.sources == [vertex_a] @@ -625,9 +643,20 @@ def test_multiple_source_vertices(self): """Test that each of the vertices associated with a source is correctly included in the sources of a net. """ + class MyVertexSlice(VertexSlice): + def __init__(self, *args, **kwargs): + super(MyVertexSlice, self).__init__(*args, **kwargs) + self.args = None + + def transmits_signal(self, signal_parameters, + transmission_parameters): + self.args = (signal_parameters, transmission_parameters) + return False + # Create the first operator vertex_a0 = VertexSlice(slice(0, 1)) vertex_a1 = VertexSlice(slice(1, 2)) + vertex_a2 = MyVertexSlice(slice(2, 3)) load_fn_a = mock.Mock(name="load function A") pre_fn_a = mock.Mock(name="pre function A") post_fn_a = mock.Mock(name="post function A") @@ -635,7 +664,8 @@ def test_multiple_source_vertices(self): object_a = mock.Mock(name="object A") operator_a = mock.Mock(name="operator A") operator_a.make_vertices.return_value = \ - netlistspec([vertex_a0, vertex_a1], load_fn_a, pre_fn_a, post_fn_a) + netlistspec([vertex_a0, vertex_a1, vertex_a2], + load_fn_a, pre_fn_a, post_fn_a) # Create the second operator vertex_b = Vertex() @@ -662,10 +692,16 @@ def test_multiple_source_vertices(self): netlist = model.make_netlist() # Check that the netlist is as expected - assert set(netlist.vertices) == set([vertex_a0, vertex_a1, vertex_b]) + assert set(netlist.vertices) == set([vertex_a0, vertex_a1, + vertex_a2, vertex_b]) assert len(netlist.nets) == 1 for net in netlist.nets: assert net.sources == [vertex_a0, vertex_a1] assert net.sinks == [vertex_b] - assert netlist.groups == [set([vertex_a0, vertex_a1])] + assert netlist.groups == [set([vertex_a0, vertex_a1, vertex_a2])] + + # Check that `transmit_signal` was called correctly + sig, tp = vertex_a2.args + assert sig.keyspace is keyspace + assert tp is None diff --git a/tests/builder/test_connection.py b/tests/builder/test_connection.py index e839b6d..355d0f3 100644 --- a/tests/builder/test_connection.py +++ b/tests/builder/test_connection.py @@ -1,5 +1,6 @@ import mock import nengo +import numpy as np from nengo_spinnaker.builder.builder import Model from nengo_spinnaker.builder.model import InputPort, OutputPort @@ -7,6 +8,9 @@ generic_source_getter, generic_sink_getter, build_generic_reception_params, + EnsembleTransmissionParameters, + PassthroughNodeTransmissionParameters, + NodeTransmissionParameters ) @@ -66,3 +70,57 @@ def test_build_standard_reception_params(): # Build the transmission parameters params = build_generic_reception_params(None, a_b) assert params.filter is a_b.synapse + + +class TestEnsembleTransmissionParameters(object): + def test_eq_ne(self): + """Create a series of EnsembleTransmissionParameters and ensure that + they only report equal when they are. + """ + class MyETP(EnsembleTransmissionParameters): + pass + + tp1 = EnsembleTransmissionParameters(np.ones((3, 3)), np.eye(3)) + tp2 = EnsembleTransmissionParameters(np.ones((1, 1)), np.eye(1)) + tp3 = EnsembleTransmissionParameters(np.eye(3), np.eye(3)) + tp4 = MyETP(np.ones((3, 3)), np.eye(3)) + + assert tp1 != tp2 + assert tp1 != tp3 + assert tp1 != tp4 + + tp5 = EnsembleTransmissionParameters(np.ones((3, 3)), np.eye(3)) + assert tp1 == tp5 + + tp6 = EnsembleTransmissionParameters(np.ones((3, 1)), np.ones((3, 1))) + assert tp1 == tp6 + + +class TestNodeTransmissionParameters(object): + def test_eq_ne(self): + class MyNTP(NodeTransmissionParameters): + pass + + # NodeTransmissionParameters are only equivalent if they are of the + # same type, they share the same pre_slice and transform. + pars = [ + (NodeTransmissionParameters, (slice(0, 5), None, np.ones((5, 5)))), + (NodeTransmissionParameters, (slice(None), None, np.ones((5, 5)))), + (NodeTransmissionParameters, (slice(0, 5), None, np.eye(5))), + (NodeTransmissionParameters, (slice(0, 5), None, np.ones((1, 1)))), + (NodeTransmissionParameters, + (slice(0, 5), lambda x: x, np.ones((5, 5)))), + (MyNTP, (slice(0, 5), None, np.ones((5, 5)))), + ] + ntps = [cls(*args) for cls, args in pars] + + # Check the inequivalence works + for a in ntps: + for b in ntps: + if a is not b: + assert a != b + + # Check that equivalence works + for a, b in zip(ntps, [cls(*args) for cls, args in pars]): + assert a is not b + assert a == b diff --git a/tests/builder/test_ensemble.py b/tests/builder/test_ensemble.py index e33e97e..ff384dd 100644 --- a/tests/builder/test_ensemble.py +++ b/tests/builder/test_ensemble.py @@ -287,28 +287,6 @@ def test_neuron_sink(self): assert sink.target.port is ensemble.EnsembleInputPort.neurons -class TestEnsembleTransmissionParameters(object): - def test_eq_ne(self): - """Create a series of EnsembleTransmissionParameters and ensure that - they only report equal when they are. - """ - class MyETP(ensemble.EnsembleTransmissionParameters): - pass - - tp1 = ensemble.EnsembleTransmissionParameters(np.ones((3, 3))) - tp2 = ensemble.EnsembleTransmissionParameters(np.ones((1, 1))) - tp3 = ensemble.EnsembleTransmissionParameters(np.eye(3)) - tp4 = MyETP(np.ones((3, 3))) - - assert tp1 != tp2 - assert tp1 != tp3 - assert tp1 != tp4 - - tp5 = ensemble.EnsembleTransmissionParameters(np.ones((3, 3))) - - assert tp1 == tp5 - - class TestBuildFromEnsembleConnection(object): """Test the construction of parameters that describe connections from Ensembles. diff --git a/tests/builder/test_model.py b/tests/builder/test_model.py index f85f2c4..a154e3f 100644 --- a/tests/builder/test_model.py +++ b/tests/builder/test_model.py @@ -291,23 +291,26 @@ def test_get_signals(self): # Add the connections cm = model.ConnectionMap() + tp_a = mock.Mock(name="Transmission Parameters A") + tp_b = mock.Mock(name="Transmission Parameters B") + cm.add_connection( obj_a, None, model.SignalParameters(weight=3, keyspace=ks_abc), - None, obj_b, None, None + tp_a, obj_b, None, None ) cm.add_connection( obj_a, None, model.SignalParameters(weight=3, keyspace=ks_abc), - None, obj_c, None, None + tp_a, obj_c, None, None ) cm.add_connection( obj_c, None, model.SignalParameters(weight=5, keyspace=ks_cb), - None, obj_b, None, None + tp_b, obj_b, None, None ) # Get the signals, this should be a list of two signals signals = list(cm.get_signals()) assert len(signals) == 2 - for signal in signals: + for signal, transmission_params in signals: if signal.source is obj_a: # Assert the sinks are correct assert len(signal.sinks) == 2 @@ -319,9 +322,15 @@ def test_get_signals(self): # Assert the weight is correct assert signal.weight == 3 + + # Assert the correct paired transmission parameters are used. + assert transmission_params is tp_a else: # Source should be C, sink B assert signal.source is obj_c assert signal.sinks == [obj_b] assert signal.keyspace is ks_cb assert signal.weight == 5 + + # Assert the correct paired transmission parameters are used. + assert transmission_params is tp_b diff --git a/tests/builder/test_node.py b/tests/builder/test_node.py index ff93fa1..b4b6d6c 100644 --- a/tests/builder/test_node.py +++ b/tests/builder/test_node.py @@ -8,8 +8,8 @@ from nengo_spinnaker.builder import Model from nengo_spinnaker.builder.model import OutputPort, InputPort from nengo_spinnaker.builder.node import ( - NodeIOController, InputNode, OutputNode, NodeTransmissionParameters, - PassthroughNodeTransmissionParameters, build_node_transmission_parameters + NodeIOController, InputNode, OutputNode, + build_node_transmission_parameters ) from nengo_spinnaker.operators import ValueSink @@ -605,36 +605,6 @@ def test_build_passthrough_node_global_inhibition(self): assert params.transform.shape == (1, 5) -class TestNodeTransmissionParameters(object): - def test_eq_ne(self): - class MyNTP(NodeTransmissionParameters): - pass - - # NodeTransmissionParameters are only equivalent if they are of the - # same type, they share the same pre_slice and transform. - pars = [ - (NodeTransmissionParameters, (slice(0, 5), None, np.ones((5, 5)))), - (NodeTransmissionParameters, (slice(None), None, np.ones((5, 5)))), - (NodeTransmissionParameters, (slice(0, 5), None, np.eye(5))), - (NodeTransmissionParameters, (slice(0, 5), None, np.ones((1, 1)))), - (NodeTransmissionParameters, - (slice(0, 5), lambda x: x, np.ones((5, 5)))), - (MyNTP, (slice(0, 5), None, np.ones((5, 5)))), - ] - ntps = [cls(*args) for cls, args in pars] - - # Check the inequivalence works - for a in ntps: - for b in ntps: - if a is not b: - assert a != b - - # Check that equivalence works - for a, b in zip(ntps, [cls(*args) for cls, args in pars]): - assert a is not b - assert a == b - - class TestInputNode(object): def test_init(self): """Test creating an new InputNode from an existing Node, this should