Skip to content

Commit

Permalink
WIP Remove passthrough Nodes at the signal level
Browse files Browse the repository at this point in the history
To do:
 - [ ] Get a signal to send to two sinks if this would have been the
       effect of the passthrough Node.
  • Loading branch information
mundya committed Jun 20, 2015
1 parent 29076c0 commit e0ff2c1
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 2 deletions.
1 change: 1 addition & 0 deletions nengo_spinnaker/simulator.py
Expand Up @@ -93,6 +93,7 @@ def __init__(self, network, dt=0.001, period=10.0):
self.model = Model(dt, decoder_cache=get_default_decoder_cache())
self.model.build(network, **builder_kwargs)
model_optimisations.remove_childless_filters(self.model)
model_optimisations.remove_passthrough_nodes(self.model)
logger.info("Build took {:.3f} seconds".format(time.time() -
start_build))

Expand Down
110 changes: 110 additions & 0 deletions nengo_spinnaker/utils/model.py
Expand Up @@ -3,6 +3,9 @@
from __future__ import absolute_import

import itertools
import nengo
from nengo.utils.builder import full_transform
import numpy as np
from six import iteritems, itervalues

from nengo_spinnaker.operators import Filter
Expand Down Expand Up @@ -82,3 +85,110 @@ def remove_childless_filters(model):
# signals which have no sinks, we should remove these as it will allow
# us to find further filters with no children.
remove_sinkless_signals(model)


def remove_passthrough_nodes(model):
"""Remove passthrough Nodes from the model."""
# Find and remove each passthrough Node in turn. To do this we take all of
# the connections feeding in to the passthrough Node and combine them with
# the connections leaving the passthrough Node. We also pair the sources
# and sinks of the signals associated with the connections.
for obj, op in iteritems(model.object_operators):
# If the object is not a Node, or not a passthrough Node then we move
# on to the next object.
# NOTE: These lines are tested but continue is optimised out.
if (not isinstance(obj, nengo.Node) or # pragma: no branch
obj.output is not None):
continue # pragma: no cover

# The object is a passthrough Node, so we get the list of all incoming
# signals and connections and all outgoing signals and connections. If
# there is anything odd we don't bother dealing with this passthrough
# Node and move onto the next.
incoming_connections_signals = list()
valid = True
for sig, conns in itertools.chain(
*(iteritems(sc) for sc in itervalues(
model.get_signals_connections_to_object(op)))):
# We're happy ONLY in cases that there is a pairing of ONE signal
# to ONE connection and there is only ONE sink for the signal.
if len(conns) != 1 or len(sig.sinks) != 1:
valid = False
break

# Just add the pair of signal and connection to the list of
# incoming signals and connections.
incoming_connections_signals.append((sig, conns[0]))

if not valid:
continue

outgoing_connections_signals = list()
for sig, conns in itertools.chain(
*(iteritems(sc) for sc in itervalues(
model.get_signals_connections_from_object(op)))):
# We're happy ONLY in cases that there is a pairing of ONE signal
# to ONE connection and there is only ONE sink for the signal.
if len(conns) != 1 or len(sig.sinks) != 1:
valid = False
break

# Just add the pair of signal and connection to the list of
# incoming signals and connections.
outgoing_connections_signals.append((sig, conns[0]))

if not valid:
continue

# Try to combine all connections and signals. Multiply the transforms
# together to find out what the final transform would be; if this is
# zero (a not uncommon occurrence) then don't bother adding the new
# signal/connection. If any of the combinations are not possible then
# abort this process (e.g., combining synapses).
new_connections = list()
for (in_sig, in_conn) in incoming_connections_signals:
for (out_sig, out_conn) in outgoing_connections_signals:
# If one of the synapses is None then we can continue
if (in_conn.synapse is not None and
out_conn.synapse is not None):
valid = False
break

# If the combination of the transforms is zero then we don't
# bother adding a new signal or connection to the model.
new_transform = np.dot(
full_transform(out_conn),
full_transform(in_conn)
)
if np.all(new_transform == 0):
continue

# Create a new connection to add this to the list of
# connections to add.
new_connections.append(
nengo.Connection(
in_conn.pre_obj, out_conn.post_obj,
function=in_conn.function,
synapse=in_conn.synapse or out_conn.synapse,
transform=new_transform,
add_to_container=False
)
)

if not valid:
break

if not valid:
continue

# Remove all the incoming and outgoing connections and signals and then
# build all the new connections.
for (in_sig, in_conn) in incoming_connections_signals:
model.connections_signals.pop(in_conn)

for (out_sig, out_conn) in outgoing_connections_signals:
model.connections_signals.pop(out_conn)

# Build all the new connections
for connection in new_connections:
model.make_connection(connection)
169 changes: 167 additions & 2 deletions tests/utils/test_utils_model.py
@@ -1,9 +1,15 @@
import mock
import nengo
import pytest
from six import iteritems

from nengo_spinnaker.builder.builder import Model, ObjectPort, Signal
from nengo_spinnaker.builder.builder import (Model, ObjectPort, Signal,
InputPort)
from nengo_spinnaker.operators import Filter
from nengo_spinnaker.utils.model import (remove_childless_filters,
remove_sinkless_signals)
remove_sinkless_signals,
remove_passthrough_nodes)
import nengo_spinnaker


def test_remove_sinkless_signals():
Expand Down Expand Up @@ -109,3 +115,162 @@ def test_remove_childless_filters():
assert model.connections_signals == {}
assert model.extra_signals == [s1, s2, s3]
assert [s.obj for s in s3.sinks] == [o2]


def test_remove_passnodes():
"""Test that passnodes can be correctly removed from a network.
We create and build the following model and then ensure that the pass node
is removed related connections are moved as required.
E1 -->\ /---> E5
\ /
E2 -->\ /---> E6
PN
E3 -->/ | \---> E7
/ |\
E4 -->/ | \---> E8
\
\----> P1
Should become:
E1 -+----------> E5
\
E2 ---+--------> E6
\
E3 ----+-------> E7
\
E4 ------+-----> E8
\
\---> P1
"""
# Create the Nengo model
with nengo.Network() as network:
# Create the Ensembles
e1_4 = [nengo.Ensemble(100, 1) for _ in range(4)]
e5_8 = [nengo.Ensemble(100, 1) for _ in range(4)]

# Add the passnode
pn = nengo.Node(size_in=4)

# And the probe
p1 = nengo.Probe(pn)

# Finally, add the connections
connections = list()
connections.extend(nengo.Connection(e1_4[n], pn[n], synapse=None) for
n in range(4))
connections.extend(nengo.Connection(pn[n], e5_8[n], synapse=None) for
n in range(4))

# Build this into a model
nioc = nengo_spinnaker.builder.node.NodeIOController()
model = Model()
model.build(network, **nioc.builder_kwargs)

# Apply the passnode remover to the model
remove_passthrough_nodes(model)

# Check that this was indeed done...
# None of the original connections should be in the model
# connections->signals mapping, but they should remain in the parameters
# dictionary.
for conn in connections:
assert conn not in model.connections_signals
assert conn in model.params

# E5..8 should have one input each, this should be a signal from their
# partner in E1..4 and should be paired with a signal with an appropriate
# transform. The signal should be of weight 1.
for i, ens in enumerate(e5_8):
# Get the object simulating the ensemble
lif = model.object_operators[ens]

# Check the incoming signals
sigs = model.get_signals_connections_to_object(lif)[InputPort.standard]
assert len(sigs) == 1

for sig, conns in iteritems(sigs):
# Check the connection is sane
assert len(conns) == 1
conn = conns[0]
assert conn.pre_obj is e1_4[i]

# Check that the signal is sane
assert sig.weight == 1
assert sig.source.obj is model.object_operators[e1_4[i]]

# P1 should receive many signals, one per pre-ensemble, and these should be
# associated with similar connections.
probe_op = model.object_operators[p1]
sigs =\
model.get_signals_connections_to_object(probe_op)[InputPort.standard]
assert len(sigs) == 4
for sig, conns in iteritems(sigs):
# Check the connection is sane
assert len(conns) == 1
conn = conns[0]
assert conn.pre_obj in e1_4

# Check that the signal is sane
# assert sig.weight == 1
assert sig.source.obj is model.object_operators[conn.pre_obj]


@pytest.mark.parametrize("signal_in", ["input", "output"])
def test_remove_passthrough_nodes_aborts_multisink(signal_in):
"""Check that removing passthrough Nodes does nothing in the case that
there is an input signal with multiple sinks.
"""
# Create a Nengo model
with nengo.Network() as network:
a = nengo.Ensemble(500, 5)
b = nengo.Node(size_in=5)
c = nengo.Ensemble(100, 5)

a_b = nengo.Connection(a, b)
b_c = nengo.Connection(b, c)

# Build this into a model
nioc = nengo_spinnaker.builder.node.NodeIOController()
model = Model()
model.build(network, **nioc.builder_kwargs)

if signal_in == "input":
# Add an extra sink to the signal associated with a->b
model.connections_signals[a_b].sinks.append(ObjectPort(None, None))
else:
# Add an extra sink to the signal associated with b->c
model.connections_signals[b_c].sinks.append(ObjectPort(None, None))

# Check that removing passthrough Nodes does nothing
remove_passthrough_nodes(model)
assert b in model.object_operators
assert a_b in model.connections_signals
assert b_c in model.connections_signals


def test_remove_passthrough_nodes_aborts_merge_filters():
"""Check that removing passthrough Nodes does nothing in the case that
there is a synapse on both the incoming and the outgoing connection.
"""
# Create a Nengo model
with nengo.Network() as network:
a = nengo.Ensemble(500, 5)
b = nengo.Node(size_in=5)
c = nengo.Ensemble(100, 5)

a_b = nengo.Connection(a, b, synapse=0.005)
b_c = nengo.Connection(b, c, synapse=0.001)

# Build this into a model
nioc = nengo_spinnaker.builder.node.NodeIOController()
model = Model()
model.build(network, **nioc.builder_kwargs)

# Check that removing passthrough Nodes does nothing
remove_passthrough_nodes(model)
assert b in model.object_operators
assert a_b in model.connections_signals
assert b_c in model.connections_signals

0 comments on commit e0ff2c1

Please sign in to comment.