Skip to content

Commit

Permalink
Include code for removing passthrough Nodes
Browse files Browse the repository at this point in the history
Modified version of code from @tcstewar.
  • Loading branch information
mundya committed Jul 22, 2015
1 parent a1a3e08 commit abe9a8f
Show file tree
Hide file tree
Showing 7 changed files with 487 additions and 147 deletions.
85 changes: 58 additions & 27 deletions nengo_spinnaker/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

from nengo_spinnaker.netlist import Net, Netlist
from nengo_spinnaker.utils import collections as collections_ext
from nengo_spinnaker.utils import passthrough_nodes as ptn_utils
from nengo_spinnaker.utils.keyspaces import KeyspaceContainer
from nengo_spinnaker.utils.probes import probe_target

BuiltConnection = collections.namedtuple(
"BuiltConnection", "decoders, eval_points, transform, solver_info"
Expand Down Expand Up @@ -208,32 +210,45 @@ def build(self, network, extra_builders={},
self._probe_builders.update(self.probe_builders)
self._probe_builders.update(extra_probe_builders)

# Build
self._build_network(network)

def _build_network(self, network):
# Get the seed for the network
self.seeds[network] = get_seed(network, np.random)
# Add connections from passthrough Nodes to all probes which target
# them.
ptn_utils.add_passthrough_node_to_probe_connections(network)

# Build all subnets
for subnet in network.networks:
self._build_network(subnet)
# Remove all of the passthrough Nodes, store a list of which
# connections should use the parameters of other connections.
network, self._prebuilt_connections = \
ptn_utils.remove_passthrough_nodes(network)

# Get the random number generator for the network
# Build all of the objects
self.seeds[network] = get_seed(network, np.random)
self.rngs[network] = np.random.RandomState(self.seeds[network])
self.rng = self.rngs[network]

# Build all objects
for obj in itertools.chain(network.ensembles, network.nodes):
for obj in itertools.chain(network.all_ensembles, network.all_nodes):
self.make_object(obj)

# Build all the connections
for connection in network.connections:
self.make_connection(connection)
# Build all "pre-built" connections
for conn in set(itervalues(self._prebuilt_connections)):
self._build_connection_params(conn)

# Build all passthrough Node probes
for probe in network.all_probes:
target_obj = probe_target(probe)
if (isinstance(target_obj, nengo.Node) and
target_obj.output is None):
self.make_probe(probe)

# Build all the probes
for probe in network.probes:
self.make_probe(probe)
# Build all connections
for conn in network.all_connections:
self.make_connection(conn)

# Build all remaining probes
for probe in network.all_probes:
target_obj = probe_target(probe)
if not (isinstance(target_obj, nengo.Node) and
target_obj.output is None):
self.make_probe(probe)

def make_object(self, obj):
"""Call an appropriate build function for the given object.
Expand All @@ -247,10 +262,31 @@ def make_connection(self, conn):
This method will build a connection and construct a new signal which
will be included in the model.
"""
self.seeds[conn] = get_seed(conn, self.rng)
self.params[conn] = \
self._connection_parameter_builders[type(conn.pre_obj)](self, conn)

self._build_connection_params(conn)
self._add_signal_for_connection(conn)

def _build_connection_params(self, conn):
"""Build the parameters for a connection."""
if conn in self._prebuilt_connections:
# If the connection is in the set of connections for which we use a
# pre-built connection as the source of the parameters we copy the
# parameters over and replace the transform.
self.seeds[conn] = self.seeds[self._prebuilt_connections[conn]]
self.params[conn] = BuiltConnection(
self.params[self._prebuilt_connections[conn]].decoders,
self.params[self._prebuilt_connections[conn]].eval_points,
conn.transform,
self.params[self._prebuilt_connections[conn]].solver_info
)
else:
# Otherwise we build the connection from scratch.
self.seeds[conn] = get_seed(conn, self.rng)
self.params[conn] = \
self._connection_parameter_builders[type(conn.pre_obj)](self,
conn)

def _add_signal_for_connection(self, conn):
"""Add a signal to simulate the connection to the model."""
# Get the source and sink specification, then make the signal provided
# that neither of specs is None.
source = self._source_getters[type(conn.pre_obj)](self, conn)
Expand All @@ -265,13 +301,8 @@ def make_probe(self, probe):
"""Call an appropriate build function for the given probe."""
self.seeds[probe] = get_seed(probe, self.rng)

# Get the target type
target_obj = probe.target
if isinstance(target_obj, nengo.base.ObjView):
target_obj = target_obj.obj

# Build
self._probe_builders[type(target_obj)](self, probe)
self._probe_builders[type(probe_target(probe))](self, probe)

def get_signals_connections_from_object(self, obj):
"""Get a dictionary mapping ports to signals to connections which
Expand Down
9 changes: 5 additions & 4 deletions nengo_spinnaker/builder/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,11 @@ def build_node_probe(self, model, probe):

# Create a new connection from the Node to the Probe and then get the
# model to build this.
seed = model.seeds[probe]
conn = nengo.Connection(probe.target, probe, synapse=probe.synapse,
seed=seed, add_to_container=False)
model.make_connection(conn)
if probe.target.output is not None:
seed = model.seeds[probe]
conn = nengo.Connection(probe.target, probe, synapse=probe.synapse,
seed=seed, add_to_container=False)
model.make_connection(conn)

def get_node_source(self, model, cn):
"""Get the source for a connection originating from a Node."""
Expand Down
161 changes: 161 additions & 0 deletions nengo_spinnaker/utils/passthrough_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Collection of tools necessary to modify networks to remove passthrough
Nodes.
"""
import nengo
from nengo.synapses import LinearFilter
from nengo.utils.builder import find_all_io, full_transform
import numpy as np
from numpy.polynomial.polynomial import polymul

from .probes import probe_target


def add_passthrough_node_to_probe_connections(network):
"""Adds new connections to the network to connect passthrough Nodes probes
which target them.
"""
# For all of the probes in the network, if any of them refer to a
# passthrough Node then add a connection from the passthrough Node to the
# probe.
for probe in network.all_probes:
obj = probe_target(probe)

# If the target is a passthrough Node then add a connection
if isinstance(obj, nengo.Node) and obj.output is None:
with network:
nengo.Connection(probe.target, probe, seed=probe.seed,
synapse=probe.synapse)


def combine_lti_synapses(a, b):
"""Combine two LTI filters."""
# Assert that both the synapses are LTI synapses
assert (isinstance(a, nengo.synapses.LinearFilter) and
isinstance(b, nengo.synapses.LinearFilter))

# Combine
return nengo.synapses.LinearFilter(polymul(a.num, b.num),
polymul(a.den, b.den))


def remove_passthrough_nodes(network):
"""Return a new network with all the passthrough Nodes removed and a
mapping from new connections to connections the decoders of which they can
use.
"""
# Get a list of the connections and a complete map of objects to inputs and
# outputs.
conns = list(network.all_connections)
inputs, outputs = find_all_io(conns)

# Prepare to create a map of which connections can just use the decoders
# from an earlier connection.
connection_decoders = dict()

# Create a new (flattened) network containing all elements from the
# original network apart from passthrough Nodes.
with nengo.Network() as m:
# Add all of the original ensembles
for ens in network.all_ensembles:
m.add(ens)

# Add all of the original probes
for probe in network.all_probes:
m.add(probe)

# For all the Nodes, if the Node is not a passthrough Node we add it as
# usual - otherwise we combine remove it and multiply together its
# input and output connections.
for node in network.all_nodes:
if node.output is not None:
with m:
m.add(node)
continue

# Remove the original connections associated with this passthrough
# Node from both the list of connections but also the lists
# associated with their pre- and post- objects.
conns_in = list(inputs[node])
conns_out = list(outputs[node])

for c in conns_in:
conns.remove(c)
outputs[c.pre_obj].remove(c)

for c in conns_out:
conns.remove(c)
inputs[c.post_obj].remove(c)

# For every outgoing connection
for out_conn in outputs[node]:
# For every incoming connection
for in_conn in inputs[node]:
use_pre_slice = in_conn.function is not None

# Create a new transform for the combined connections. If
# the transform is zero then we don't bother adding a new
# connection and instead move onto the next combination. If the
# in connection doesn't have a function then we include the
# pre-slice in the transform, otherwise we ignore it.
transform = np.dot(
full_transform(out_conn),
full_transform(in_conn, slice_pre=not use_pre_slice)
)

if np.all(transform == 0.0):
continue

# We determine if we can combine the synapses. If we can't
# we raise an error because we can't do anything at the
# moment.
if out_conn.synapse is None or in_conn.synapse is None:
# Trivial combination of synapses
new_synapse = out_conn.synapse or in_conn.synapse
elif (isinstance(in_conn.synapse, LinearFilter) and
isinstance(out_conn.synapse, LinearFilter)):
# Combination of LTI systems
print("Combining synapses of {} and {}".format(
in_conn, out_conn))
new_synapse = combine_lti_synapses(in_conn.synapse,
out_conn.synapse)
else:
# Can't combine these filters
raise NotImplementedError(
"Can't combine synapses of types {} and {}".format(
in_conn.synapse.__class__.__name__,
out_conn.synapse.__class__.__name__
)
)

# Create a new connection that combines the inputs and outputs.
new_c = nengo.Connection(
in_conn.pre if use_pre_slice else in_conn.pre_obj,
out_conn.post_obj,
function=in_conn.function,
synapse=new_synapse,
transform=transform,
add_to_container=False
)

# Add this connection to the list of connections to add to the
# model and the lists of outgoing and incoming connections for
# objects.
conns.append(new_c)
inputs[new_c.post_obj].append(new_c)
outputs[new_c.pre_obj].append(new_c)

# Determine which decoders should be used for this connection
# if the pre object is an ensemble.
if isinstance(in_conn.pre_obj, nengo.Ensemble):
x = in_conn
while x in connection_decoders:
x = connection_decoders[x]
connection_decoders[new_c] = x

# Add all the connections
with m:
for c in conns:
m.add(c)

# Return the new network and map of connections
return m, connection_decoders
11 changes: 11 additions & 0 deletions nengo_spinnaker/utils/probes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from nengo.base import ObjView


def probe_target(probe):
"""Get the target object of a probe."""
if isinstance(probe.target, ObjView):
# If the target is an object view then return the underlying object
return probe.target.obj
else:
# Otherwise return the target
return probe.target

0 comments on commit abe9a8f

Please sign in to comment.