Skip to content

Commit

Permalink
Merge branch 'master' into multisource-nets
Browse files Browse the repository at this point in the history
  • Loading branch information
mundya committed Dec 8, 2015
2 parents 4593e1a + ceb5a0b commit ee48b82
Show file tree
Hide file tree
Showing 10 changed files with 251 additions and 159 deletions.
36 changes: 33 additions & 3 deletions nengo_spinnaker/builder/builder.py
Expand Up @@ -344,13 +344,43 @@ def make_netlist(self, *args, **kwargs):

# Construct nets from the signals
nets = list()
for signal in self.connection_map.get_signals():
for signal, transmission_parameters in \
self.connection_map.get_signals():
# Get the source and sink vertices
sources = operator_vertices[signal.source]
original_sources = operator_vertices[signal.source]
if not isinstance(original_sources, collections.Iterable):
original_sources = (original_sources, )

# Filter out any sources which have an `accepts_signal` method and
# return False when this is called with the signal and transmission
# parameters.
sources = list()
for source in original_sources:
# For each source which either doesn't have a
# `transmits_signal` method or returns True when this is called
# with the signal and transmission parameters add a new net to
# the netlist.
if (hasattr(source, "transmits_signal") and not
source.transmits_signal(signal,
transmission_parameters)):
pass # This source is ignored
else:
# Add the source to the final list of sources
sources.append(source)

sinks = collections_ext.flatinsertionlist()
for sink in signal.sinks:
sinks.append(operator_vertices[sink])
# Get all the sink vertices
sink_vertices = operator_vertices[sink]
if not isinstance(sink_vertices, collections.Iterable):
sink_vertices = (sink_vertices, )

# Include any sinks which either don't have an `accepts_signal`
# method or return true when this is called with the signal and
# transmission parameters.
sinks.append(s for s in sink_vertices if
not hasattr(s, "accepts_signal") or
s.accepts_signal(signal, transmission_parameters))

# Create the net(s)
nets.append(NMNet(sources, list(sinks),
Expand Down
94 changes: 94 additions & 0 deletions nengo_spinnaker/builder/connection.py
@@ -1,4 +1,6 @@
import nengo
import numpy as np

from .builder import Model, ObjectPort, spec
from .model import ReceptionParameters, InputPort, OutputPort

Expand All @@ -23,3 +25,95 @@ def build_generic_reception_params(model, conn):
"""
# Just extract the synapse from the connection.
return ReceptionParameters(conn.synapse, conn.post_obj.size_in)


class EnsembleTransmissionParameters(object):
"""Transmission parameters for a connection originating at an Ensemble.
Attributes
----------
decoders : array
Decoders to use for the connection.
"""
def __init__(self, decoders, transform):
# Copy the decoders
self.untransformed_decoders = np.array(decoders)
self.transform = np.array(transform)

# Compute and store the transformed decoders
self.decoders = np.dot(transform, decoders.T).T

# Make the arrays read-only
self.untransformed_decoders.flags['WRITEABLE'] = False
self.transform.flags['WRITEABLE'] = False
self.decoders.flags['WRITEABLE'] = False

def __ne__(self, other):
return not (self == other)

def __eq__(self, other):
# Equal iff. the objects are of the same type
if type(self) is not type(other):
return False

# Equal iff. the decoders are the same shape
if self.decoders.shape != other.decoders.shape:
return False

# Equal iff. the decoder values are the same
if np.any(self.decoders != other.decoders):
return False

return True


class PassthroughNodeTransmissionParameters(object):
"""Parameters describing connections which originate from pass through
Nodes.
"""
def __init__(self, transform):
# Store the parameters, copying the transform
self.transform = np.array(transform)

def __ne__(self, other):
return not (self == other)

def __eq__(self, other):
# Equivalent if the same type
if type(self) is not type(other):
return False

# and the transforms are equivalent
if (self.transform.shape != other.transform.shape or
np.any(self.transform != other.transform)):
return False

return True


class NodeTransmissionParameters(PassthroughNodeTransmissionParameters):
"""Parameters describing connections which originate from Nodes."""
def __init__(self, pre_slice, function, transform):
# Store the parameters
super(NodeTransmissionParameters, self).__init__(transform)
self.pre_slice = pre_slice
self.function = function

def __hash__(self):
# Hash by ID
return hash(id(self))

def __eq__(self, other):
# Parent equivalence
if not super(NodeTransmissionParameters, self).__eq__(other):
return False

# Equivalent if the pre_slices are exactly the same
if self.pre_slice != other.pre_slice:
return False

# Equivalent if the functions are the same
if self.function is not other.function:
return False

return True
37 changes: 2 additions & 35 deletions nengo_spinnaker/builder/ensemble.py
Expand Up @@ -9,6 +9,7 @@
import numpy as np

from .builder import BuiltConnection, Model, ObjectPort, spec
from .connection import EnsembleTransmissionParameters
from .model import InputPort
from .ports import EnsembleInputPort
from .. import operators
Expand Down Expand Up @@ -137,37 +138,6 @@ def build_lif(model, ens):
model.object_operators[ens] = operators.EnsembleLIF(ens)


class EnsembleTransmissionParameters(object):
"""Transmission parameters for a connection originating at an Ensemble.
Attributes
----------
decoders : array
Decoders to use for the connection (including the transform).
"""
def __init__(self, decoders):
# Copy the decoders
self.decoders = np.array(decoders)

def __ne__(self, other):
return not (self == other)

def __eq__(self, other):
# Equal iff. the objects are of the same type
if type(self) is not type(other):
return False

# Equal iff. the decoders are the same shape
if self.decoders.shape != other.decoders.shape:
return False

# Equal iff. the decoder values are the same
if np.any(self.decoders != other.decoders):
return False

return True


@Model.transmission_parameter_builders.register(nengo.Ensemble)
def build_from_ensemble_connection(model, conn):
"""Build the parameters object for a connection from an Ensemble."""
Expand Down Expand Up @@ -198,10 +168,7 @@ def build_from_ensemble_connection(model, conn):
np.all(transform[0, :] == transform[1:, :])):
transform = np.array([transform[0]])

# Multiply the decoders by the transform and return this as the
# transmission parameters.
full_decoders = np.dot(transform, decoders.T).T
return EnsembleTransmissionParameters(full_decoders)
return EnsembleTransmissionParameters(decoders, transform)


@Model.transmission_parameter_builders.register(nengo.ensemble.Neurons)
Expand Down
14 changes: 7 additions & 7 deletions nengo_spinnaker/builder/model.py
Expand Up @@ -156,14 +156,14 @@ def get_signals(self):
# For each source object and set of sinks yield a new signal
for source, port_conns in iteritems(self._connections):
# For each connection look at the sinks and the signal parameters
for (sig_pars, _), par_sinks in chain(*itervalues(port_conns)):
for (sig_pars, transmission_pars), par_sinks in \
chain(*itervalues(port_conns)):
# Create a signal using these parameters
yield Signal(
source,
(ps.sink_object for ps in par_sinks), # Extract the sinks
sig_pars.keyspace,
sig_pars.weight
)
yield (Signal(source,
(ps.sink_object for ps in par_sinks), # Sinks
sig_pars.keyspace,
sig_pars.weight),
transmission_pars)


class OutputPort(enum.Enum):
Expand Down
54 changes: 2 additions & 52 deletions nengo_spinnaker/builder/node.py
Expand Up @@ -6,6 +6,8 @@

from nengo_spinnaker.builder.builder import ObjectPort, spec, Model
from nengo_spinnaker.builder.model import InputPort, OutputPort
from .connection import (PassthroughNodeTransmissionParameters,
NodeTransmissionParameters)
from nengo_spinnaker.operators import Filter, ValueSink, ValueSource
from nengo_spinnaker.utils.config import getconfig

Expand Down Expand Up @@ -294,58 +296,6 @@ def build_node_transmission_parameters(model, conn):
return PassthroughNodeTransmissionParameters(transform)


class PassthroughNodeTransmissionParameters(object):
"""Parameters describing connections which originate from pass through
Nodes.
"""
def __init__(self, transform):
# Store the parameters, copying the transform
self.transform = np.array(transform)

def __ne__(self, other):
return not (self == other)

def __eq__(self, other):
# Equivalent if the same type
if type(self) is not type(other):
return False

# and the transforms are equivalent
if (self.transform.shape != other.transform.shape or
np.any(self.transform != other.transform)):
return False

return True


class NodeTransmissionParameters(PassthroughNodeTransmissionParameters):
"""Parameters describing connections which originate from Nodes."""
def __init__(self, pre_slice, function, transform):
# Store the parameters
super(NodeTransmissionParameters, self).__init__(transform)
self.pre_slice = pre_slice
self.function = function

def __hash__(self):
# Hash by ID
return hash(id(self))

def __eq__(self, other):
# Parent equivalence
if not super(NodeTransmissionParameters, self).__eq__(other):
return False

# Equivalent if the pre_slices are exactly the same
if self.pre_slice != other.pre_slice:
return False

# Equivalent if the functions are the same
if self.function is not other.function:
return False

return True


class InputNode(nengo.Node):
"""Node which queries the IO controller for the input to a Node from."""
def __init__(self, node, controller):
Expand Down
44 changes: 40 additions & 4 deletions tests/builder/test_builder.py
Expand Up @@ -594,6 +594,14 @@ def test_multiple_sink_vertices(self):
operator_b.make_vertices.return_value = \
netlistspec([vertex_b0, vertex_b1], load_fn_b)

# Create a third operator, which won't accept the signal
vertex_c = mock.Mock(name="vertex C")
vertex_c.accepts_signal.side_effect = lambda _, __: False

object_c = mock.Mock(name="object C")
operator_c = mock.Mock(name="operator C")
operator_c.make_vertices.return_value = netlistspec(vertex_c)

# Create a signal between the operators
keyspace = mock.Mock(name="keyspace")
keyspace.length = 32
Expand All @@ -603,14 +611,24 @@ def test_multiple_sink_vertices(self):
model = Model()
model.object_operators[object_a] = operator_a
model.object_operators[object_b] = operator_b
model.object_operators[object_c] = operator_c
model.connection_map.add_connection(
operator_a, None, signal_ab_parameters, None,
operator_b, None, None
)
model.connection_map.add_connection(
operator_a, None, signal_ab_parameters, None,
operator_c, None, None
)
netlist = model.make_netlist()

# Check that the "accepts_signal" method of vertex_c was called with
# reasonable arguments
assert vertex_c.accepts_signal.called

# Check that the netlist is as expected
assert set(netlist.vertices) == set([vertex_a, vertex_b0, vertex_b1])
assert set(netlist.vertices) == set(
[vertex_a, vertex_b0, vertex_b1, vertex_c])
assert len(netlist.nets) == 1
for net in netlist.nets:
assert net.sources == [vertex_a]
Expand All @@ -625,17 +643,29 @@ def test_multiple_source_vertices(self):
"""Test that each of the vertices associated with a source is correctly
included in the sources of a net.
"""
class MyVertexSlice(VertexSlice):
def __init__(self, *args, **kwargs):
super(MyVertexSlice, self).__init__(*args, **kwargs)
self.args = None

def transmits_signal(self, signal_parameters,
transmission_parameters):
self.args = (signal_parameters, transmission_parameters)
return False

# Create the first operator
vertex_a0 = VertexSlice(slice(0, 1))
vertex_a1 = VertexSlice(slice(1, 2))
vertex_a2 = MyVertexSlice(slice(2, 3))
load_fn_a = mock.Mock(name="load function A")
pre_fn_a = mock.Mock(name="pre function A")
post_fn_a = mock.Mock(name="post function A")

object_a = mock.Mock(name="object A")
operator_a = mock.Mock(name="operator A")
operator_a.make_vertices.return_value = \
netlistspec([vertex_a0, vertex_a1], load_fn_a, pre_fn_a, post_fn_a)
netlistspec([vertex_a0, vertex_a1, vertex_a2],
load_fn_a, pre_fn_a, post_fn_a)

# Create the second operator
vertex_b = Vertex()
Expand All @@ -662,10 +692,16 @@ def test_multiple_source_vertices(self):
netlist = model.make_netlist()

# Check that the netlist is as expected
assert set(netlist.vertices) == set([vertex_a0, vertex_a1, vertex_b])
assert set(netlist.vertices) == set([vertex_a0, vertex_a1,
vertex_a2, vertex_b])
assert len(netlist.nets) == 1
for net in netlist.nets:
assert net.sources == [vertex_a0, vertex_a1]
assert net.sinks == [vertex_b]

assert netlist.groups == [set([vertex_a0, vertex_a1])]
assert netlist.groups == [set([vertex_a0, vertex_a1, vertex_a2])]

# Check that `transmit_signal` was called correctly
sig, tp = vertex_a2.args
assert sig.keyspace is keyspace
assert tp is None

0 comments on commit ee48b82

Please sign in to comment.