Skip to content

Commit

Permalink
Expand connections_signals for multiple signals
Browse files Browse the repository at this point in the history
Allow a connection to be simulated multiple signals as well as the
existing ability of a signal to simulate multiple connections. This is
necessary for building merge trees where a connection may need to be
simulated by multiple signals (o1 -> merge node ... -> o2).

Add a `all_signals` iterator to `Model` which yields each signal only
once.
  • Loading branch information
mundya committed Jun 10, 2015
1 parent ad6e637 commit 9b27aef
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 35 deletions.
96 changes: 85 additions & 11 deletions nengo_spinnaker/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,40 @@ def __init__(self, dt=0.001, machine_timestep=1000,
self.machine_timestep = machine_timestep
self.decoder_cache = decoder_cache

# Map Nengo objects to their built parameters (e.g. Connections ->
# Decoders, ...)
self.params = dict()

# Map of Nengo objects to the seeds they are assigned and to their
# associated random number generators.
self.seeds = dict()
self.rngs = dict()
self.rng = None

# A reference to the config which may modify how the network is built.
self.config = None

# A map of Nengo objects to the "operators" which will handle their
# simulation on SpiNNaker.
self.object_operators = dict()

# A list of extra "operators"
self.extra_operators = list()
self.connections_signals = dict()

# Map of Nengo connections to lists of signals which lead to their
# simulation on SpiNNaker. This is set up so that a single signal may
# be used to simulate multiple connections (e.g., the same packets will
# end up being routed to multiple sinks, like for spikes) and a single
# connection may be simulated by multiple signals, this will occur in
# merge trees where at least two signals are needed per connection
# (source -> merge node -> sink).
self.connections_signals = collections.defaultdict(list)

# Additional signals
self.extra_signals = list()

# The keyspaces dictionary automatically assigns separate regions of
# the overall keyspace to various users (e.g., external devices).
if keyspaces is None:
keyspaces = KeyspaceContainer()
self.keyspaces = keyspaces
Expand Down Expand Up @@ -258,8 +281,8 @@ def make_connection(self, conn):

if source is not None and sink is not None:
assert conn not in self.connections_signals
self.connections_signals[conn] = _make_signal(self, conn,
source, sink)
self.connections_signals[conn].append(
_make_signal(self, conn, source, sink))

def make_probe(self, probe):
"""Call an appropriate build function for the given probe."""
Expand All @@ -277,13 +300,26 @@ def get_signals_connections_from_object(self, obj):
"""Get a dictionary mapping ports to signals to connections which
originate from a given intermediate object.
"""
# {port: {signal: [connection, ...], ...}, ...}
ports_sigs_conns = collections.defaultdict(
lambda: collections.defaultdict(collections_ext.noneignoringlist)
)

for (conn, signal) in itertools.chain(
iteritems(self.connections_signals),
((None, s) for s in self.extra_signals)):
# Create an iterator which will yield pairs of connections and signals.
# Each pair of connection and list of signals from
# `connections_signals` is expanded into pairs of connections and
# signals. Each signal from `extra_signals` is paired with None to
# indicate that it is unrelated to a connection.
conn_sigs = itertools.chain(
((c, s) for c, ss in iteritems(self.connections_signals)
for s in ss),
((None, s) for s in self.extra_signals)
)

# Now iterate over this, if the source object of any of the signals is
# the object we care about then we add the port, signal and connection
# to the dictionary that we're building.
for (conn, signal) in conn_sigs:
if signal.source.obj is obj:
ports_sigs_conns[signal.source.port][signal].append(conn)

Expand All @@ -293,19 +329,58 @@ def get_signals_connections_to_object(self, obj):
"""Get a dictionary mapping ports to signals to connections which
terminate at a given intermediate object.
"""
# {port: {signal: [connection, ...], ...}, ...}
ports_sigs_conns = collections.defaultdict(
lambda: collections.defaultdict(collections_ext.noneignoringlist)
)

for (conn, signal) in itertools.chain(
iteritems(self.connections_signals),
((None, s) for s in self.extra_signals)):
# Create an iterator which will yield pairs of connections and signals.
# Each pair of connection and list of signals from
# `connections_signals` is expanded into pairs of connections and
# signals. Each signal from `extra_signals` is paired with None to
# indicate that it is unrelated to a connection.
conn_sigs = itertools.chain(
((c, s) for c, ss in iteritems(self.connections_signals)
for s in ss),
((None, s) for s in self.extra_signals)
)

# Now iterate over this, if any of the sinks a signal is the object we
# care about then we add the port, signal and connection to the
# dictionary that we are building.
for (conn, signal) in conn_sigs:
for sink in signal.sinks:
if sink.obj is obj:
ports_sigs_conns[sink.port][signal].append(conn)

return ports_sigs_conns

def all_signals(self):
"""Iterator of all signals within a model.
Yields
------
Signal
All of the signals from a model, once each.
"""
# Create an iterator over all of the signals in the model. Each signal
# may appear multiple times in this iteration.
all_signals = itertools.chain(self.extra_signals,
*itervalues(self.connections_signals))

# To avoid yielding the same signal twice keep track of which signals
# we've seen and only yield signals which are not in this set.
seen_signals = set()
for sig in all_signals:
# If we've already seen the signal then skip to the next.
if sig in seen_signals:
continue

# Otherwise add the signal to the set of signals we've seen and
# then yield.
seen_signals.add(sig)
yield sig

def make_netlist(self, *args, **kwargs):
"""Convert the model into a netlist for simulating on SpiNNaker.
Expand Down Expand Up @@ -345,8 +420,7 @@ def make_netlist(self, *args, **kwargs):

# Construct nets from the signals
nets = list()
for signal in itertools.chain(itervalues(self.connections_signals),
self.extra_signals):
for signal in self.all_signals():
# Get the source and sink vertices
sources = operator_vertices[signal.source.obj]
if not isinstance(sources, collections.Iterable):
Expand Down
10 changes: 6 additions & 4 deletions nengo_spinnaker/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import absolute_import

import itertools
from six import iteritems, itervalues
from six import iteritems

from nengo_spinnaker.operators import Filter

Expand All @@ -14,7 +14,8 @@ def remove_sinkless_signals(model):
"""
# Create a list of signals to remove by iterating through the signals which
# are related to connections and finding any with no sinks.
sinkless_signals = [(c, s) for c, s in iteritems(model.connections_signals)
sinkless_signals = [(c, s) for c, ss in
iteritems(model.connections_signals) for s in ss
if len(s.sinks) == 0]

# Now remove all sinkless signals
Expand Down Expand Up @@ -66,8 +67,9 @@ def remove_childless_filters(model):
for obj, filt in childless_filters:
# Remove the filter from the list of sinks of each of the signals
# which target it.
for sig in itertools.chain(itervalues(model.connections_signals),
model.extra_signals):
for sig in model.all_signals():
# Prepare and remove sinks which target the object we're
# removing.
sinks = [s for s in sig.sinks if s.obj is filt]
for sink in sinks:
sig.sinks.remove(sink)
Expand Down
90 changes: 76 additions & 14 deletions tests/builder/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def connection_builder_fn(m, c):
assert model.params[connection] is built_connection

# Assert that the signal exists
signal = model.connections_signals[connection]
signal = model.connections_signals[connection][0]
assert signal.source is source
assert signal.sinks == [sink]

Expand Down Expand Up @@ -669,11 +669,11 @@ def test_get_signals_and_connections_starting_from(self):
# Create a model holding all of these items
model = Model()
model.connections_signals = {
conn_ab1: sig_ab1,
conn_ab2: sig_ab2,
conn_ba1: sig_ba1,
conn_ba2: sig_ba2,
conn_ba3: sig_ba2,
conn_ab1: [sig_ab1],
conn_ab2: [sig_ab2],
conn_ba1: [sig_ba1],
conn_ba2: [sig_ba2],
conn_ba3: [sig_ba2],
}
model.extra_signals = [sig_ab3, sig_ab4, sig_ba3]

Expand Down Expand Up @@ -739,11 +739,11 @@ def test_get_signals_and_connections_terminating_at(self):
# Create a model holding all of these items
model = Model()
model.connections_signals = {
conn_ab1: sig_ab1,
conn_ab2: sig_ab2,
conn_ba1: sig_ba1,
conn_ba2: sig_ba2,
conn_ba3: sig_ba2,
conn_ab1: [sig_ab1],
conn_ab2: [sig_ab2],
conn_ba1: [sig_ba1],
conn_ba2: [sig_ba2],
conn_ba3: [sig_ba2],
}
model.extra_signals = [sig_ab3]

Expand Down Expand Up @@ -773,6 +773,44 @@ def test_get_signals_and_connections_terminating_at(self):
},
}

def test_not_all_originate_or_terminate(self):
"""Test that we get valid signal and connection pairs when not all
signals associated with a connection terminate at the object of
interest.
|~~~~~c1~~|
v v
o1 -s1--> o2 -s2--> o3
`c1` is simulated by `s1` and `s2` but only `s2` terminates at `o3`.
"""
# Construct all of the objects
o1 = mock.Mock(name="o1")
o2 = mock.Mock(name="o2")
o3 = mock.Mock(name="o3")

# Create C1
c1 = mock.Mock()

# Create the signals
s1 = Signal(ObjectPort(o1, None), ObjectPort(o2, None), None)
s2 = Signal(ObjectPort(o2, None), ObjectPort(o3, None), None)

# Construct the model
model = Model()
model.connections_signals[c1] = [s1, s2]

# Check that we can grab the signals and connections originating from
# o1.
from_o1 = model.get_signals_connections_from_object(o1)
assert len(from_o1) == 1
assert from_o1 == {None: {s1: [c1]}}

# Check that we can grab the signals and connections terminating at o3.
to_o3 = model.get_signals_connections_to_object(o3)
assert len(to_o3) == 1
assert to_o3 == {None: {s2: [c1]}}


class TestMakeNetlist(object):
"""Test production of netlists from operators and signals."""
Expand Down Expand Up @@ -812,7 +850,7 @@ def test_single_vertices(self):
model = Model()
model.object_operators[object_a] = operator_a
model.object_operators[object_b] = operator_b
model.connections_signals[None] = signal_ab
model.connections_signals[None] = [signal_ab]
netlist = model.make_netlist()

# Check that the make_vertices functions were called
Expand Down Expand Up @@ -921,7 +959,7 @@ def test_multiple_sink_vertices(self):
model = Model()
model.object_operators[object_a] = operator_a
model.object_operators[object_b] = operator_b
model.connections_signals[None] = signal_ab
model.connections_signals[None] = [signal_ab]
netlist = model.make_netlist()

# Check that the netlist is as expected
Expand Down Expand Up @@ -969,7 +1007,7 @@ def test_multiple_source_vertices(self):
model = Model()
model.object_operators[object_a] = operator_a
model.object_operators[object_b] = operator_b
model.connections_signals[None] = signal_ab
model.connections_signals[None] = [signal_ab]
netlist = model.make_netlist()

# Check that the netlist is as expected
Expand All @@ -980,3 +1018,27 @@ def test_multiple_source_vertices(self):
assert net.sinks == [vertex_b]

assert netlist.groups == [set([vertex_a0, vertex_a1])]


def test_all_signals():
"""Test that a list of signals can be extracted from a model and that the
same signal is not presented twice in the iterator.
"""
# Create 2 signals
s1 = Signal(ObjectPort(mock.Mock(), None), ObjectPort(mock.Mock(), None),
None)
s2 = Signal(ObjectPort(mock.Mock(), None), ObjectPort(mock.Mock(), None),
None)

# Create a connection
c1 = mock.Mock()

# Create the model
model = Model()
model.connections_signals[c1] = [s1, s2]
model.extra_signals.append(s1)

# Get all the signals, there should only be two
sigs = list(model.all_signals())
assert len(sigs) == 2
assert set(sigs) == set([s1, s2])
4 changes: 2 additions & 2 deletions tests/operators/test_sdp_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def test_make_vertices(self):
sig_b = Signal(ObjectPort(sdp_rx, OutputPort.standard), None, ks_b)

model.connections_signals = {
conn_a: sig_a,
conn_b: sig_b,
conn_a: [sig_a],
conn_b: [sig_b],
}

# Make the vertices
Expand Down
8 changes: 4 additions & 4 deletions tests/utils/test_utils_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ def test_remove_sinkless_signals():
# Create the model
model = Model()
model.extra_operators = [o1, o2]
model.connections_signals = {c1: cs1, c2: cs2}
model.connections_signals = {c1: [cs1], c2: [cs2]}
model.extra_signals = [ss1, ss2]

# Remove sinkless signals
remove_sinkless_signals(model)

# Check that signals were removed as necessary
assert model.connections_signals == {c1: cs1}
assert model.connections_signals == {c1: [cs1]}
assert model.extra_signals == [ss1]


Expand Down Expand Up @@ -95,8 +95,8 @@ def test_remove_childless_filters():
}
model.extra_operators = [f1, f2, f4, f5]
model.connections_signals = {
cs4: s4,
cs5: s5,
cs4: [s4],
cs5: [s5],
}
model.extra_signals = [s1, s2, s3, s6]

Expand Down

0 comments on commit 9b27aef

Please sign in to comment.