Skip to content

Commit

Permalink
Store the signals which arrive at an operator
Browse files Browse the repository at this point in the history
Speed the build process by computing the list of signals which arrive at
all operators once rather than recomputing it every time it is
requested. This makes an unsurprisingly large difference in build time
for large models.
  • Loading branch information
mundya committed Jan 25, 2017
1 parent 21c7534 commit e64fbba
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 40 deletions.
19 changes: 13 additions & 6 deletions nengo_spinnaker/builder/builder.py
Expand Up @@ -142,10 +142,12 @@ def __init__(self, dt=0.001, machine_timestep=1000,
self.rngs = dict()
self.rng = None

# Model data
self.config = None
self.object_operators = dict()
self.extra_operators = list()
self.connection_map = model.ConnectionMap()
self._incoming_signals = None

if keyspaces is None:
keyspaces = KeyspaceContainer()
Expand Down Expand Up @@ -281,7 +283,7 @@ def make_probe(self, probe):
# Build
self._probe_builders[type(target_obj)](self, probe)

def get_signals_from_object(self, *args): # pragma : no cover
def get_signals_from_object(self, source_object):
"""Get the signals transmitted by a source object.
Returns
Expand All @@ -290,9 +292,9 @@ def get_signals_from_object(self, *args): # pragma : no cover
Dictionary mapping ports to lists of parameters for the signals
that originate from them.
"""
return self.connection_map.get_signals_from_object(*args)
return self.connection_map.get_signals_from_object(source_object)

def get_signals_to_object(self, *args): # pragma : no cover
def get_signals_to_object(self, sink_object):
"""Get the signals received by a sink object.
Returns
Expand All @@ -301,7 +303,12 @@ def get_signals_to_object(self, *args): # pragma : no cover
Dictionary mapping ports to the lists of objects specifying
incoming signals.
"""
return self.connection_map.get_signals_to_object(*args)
if self._incoming_signals is None:
# Get faster access to incoming signals from the connection map.
self._incoming_signals =\
self.connection_map.get_signals_to_all_objects()

return self._incoming_signals[sink_object]

def make_netlist(self, *args, **kwargs):
"""Convert the model into a netlist for simulating on SpiNNaker.
Expand Down Expand Up @@ -356,8 +363,8 @@ def make_netlist(self, *args, **kwargs):
else:
# Otherwise assume that all signals arriving at the operator
# must be uniquely identified.
incoming_all = itertools.chain(
*itervalues(self.connection_map.get_signals_to_object(op)))
incoming_all = itertools.chain(*itervalues(
self.get_signals_to_object(op)))
for (u, _), (v, _) in itertools.combinations(incoming_all, 2):
if u != v:
id_constraints[u].add(v)
Expand Down
58 changes: 26 additions & 32 deletions nengo_spinnaker/builder/model.py
@@ -1,6 +1,6 @@
"""Objects used to represent Nengo networks as instantiated on SpiNNaker.
"""
import collections
from collections import namedtuple, defaultdict
import enum
from .ports import EnsembleInputPort
from six import iteritems, itervalues, iterkeys
Expand Down Expand Up @@ -45,9 +45,9 @@ class ConnectionMap(object):
def __init__(self):
"""Create a new empty connection map."""
# Construct the connection map internal structure
self._connections = collections.defaultdict(
lambda: collections.defaultdict(
lambda: collections.defaultdict(list)
self._connections = defaultdict(
lambda: defaultdict(
lambda: defaultdict(list)
)
)

Expand Down Expand Up @@ -94,7 +94,7 @@ def get_signals_from_object(self, source_object):
Dictionary mapping ports to lists of parameters for the signals
that originate from them.
"""
signals = collections.defaultdict(list)
signals = defaultdict(list)

# For every port and list of (transmission pars, sinks) associated with
# it add the transmission parameters to the correct list of signals.
Expand All @@ -103,36 +103,32 @@ def get_signals_from_object(self, source_object):

return signals

def get_signals_to_object(self, sink_object):
"""Get the signals received by a sink object.
def get_signals_to_all_objects(self):
"""Get the signals received by all sink objects.
Returns
-------
{port : [ReceptionSpec, ...], ...}
Dictionary mapping ports to the lists of objects specifying
incoming signals.
{object: {port : [ReceptionSpec, ...], ...}, ...}
Dictionary mapping objects to mappings from ports to lists of
objects specifying incoming signals.
"""
signals = collections.defaultdict(list)
incoming_signals = defaultdict(lambda: defaultdict(list))

# For all connections we have reference to identify those which
# terminate at the given object. For those that do add a new entry to
# the signal dictionary.
for port_conns in itervalues(self._connections):
for conns in itervalues(port_conns):
for (sig_params, _), sinks in iteritems(conns):
# For each sink, if the sink object is the specified object
# then add signal to the list.
for sink in sinks:
if sink.sink_object is sink_object:
# This is the desired sink object, so remember the
# signal. First construction the reception
# specification.
signals[sink.port].append(
ReceptionSpec(sig_params,
sink.reception_parameters)
)
# This is the desired sink object, so remember the
# signal. First construction the reception
# specification.
incoming_signals[sink.sink_object][sink.port].append(
ReceptionSpec(sig_params,
sink.reception_parameters)
)

return signals
return incoming_signals

def get_signals(self):
"""Extract all the signals from the connection map.
Expand Down Expand Up @@ -204,8 +200,8 @@ def __ne__(self, b):
return not self == b


ReceptionParameters = collections.namedtuple("ReceptionParameters",
"filter, width, learning_rule")
ReceptionParameters = namedtuple("ReceptionParameters",
"filter, width, learning_rule")
"""Basic reception parameters that relate to the reception of a series of
multicast packets.
Expand All @@ -218,7 +214,7 @@ def __ne__(self, b):
"""


class _ParsSinksPair(collections.namedtuple("_PSP", "parameters, sinks")):
class _ParsSinksPair(namedtuple("_PSP", "parameters, sinks")):
"""Pair of transmission parameters and sink tuples."""
def __new__(cls, signal_parameters, sinks=list()):
# Copy the sinks list before calling __new__
Expand All @@ -227,15 +223,13 @@ def __new__(cls, signal_parameters, sinks=list()):
sinks)


_SinkPars = collections.namedtuple("_SinkPars", ["sink_object", "port",
"reception_parameters"])
_SinkPars = namedtuple("_SinkPars", ["sink_object", "port",
"reception_parameters"])
"""Collection of parameters for a sink."""


ReceptionSpec = collections.namedtuple(
"ReceptionSpec", ["signal_parameters",
"reception_parameters"]
)
ReceptionSpec = namedtuple("ReceptionSpec", ["signal_parameters",
"reception_parameters"])
"""Specification of an incoming connection.
Attributes
Expand Down
4 changes: 2 additions & 2 deletions tests/builder/test_model.py
Expand Up @@ -154,7 +154,7 @@ def test_get_signals_from_object(self):
assert (sp1, tp2) in sigs_b[sp_2]
assert (sp2, tp2) in sigs_b[sp_2]

def test_get_signals_to_object(self):
def test_get_signals_to_all_objects(self):
# Create two ports
sp_1 = mock.Mock(name="Sink Port 1")
sp_2 = mock.Mock(name="Sink Port 2")
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_get_signals_to_object(self):
cm.add_connection(None, None, tp1, None, sink_b, sp_1, rp)

# Get the signals to sink_a, check that they are as expected
sigs_a = cm.get_signals_to_object(sink_a)
sigs_a = cm.get_signals_to_all_objects()[sink_a]
assert len(sigs_a[sp_1]) == 2
seen_rps = []

Expand Down

0 comments on commit e64fbba

Please sign in to comment.