Skip to content

Commit

Permalink
Change source/sink_getter return type.
Browse files Browse the repository at this point in the history
  • Loading branch information
mundya committed Apr 5, 2015
1 parent 6e2837a commit 94a1849
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 116 deletions.
10 changes: 5 additions & 5 deletions nengo_spinnaker/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_ensemble_sink(conn, irn):

irn.object_map[conn.post_obj].direct_input += np.dot(
full_transform(conn, slice_pre=False), val)
return None, {} # No connection should be made
return None # No connection should be made

# Otherwise connecting to an Ensemble is just like connecting to anything
# else.
Expand All @@ -83,13 +83,13 @@ def get_neurons_sink(conn, irn):
"""
if isinstance(conn.pre_obj, nengo.ensemble.Neurons):
# Neurons -> Neurons connection
return (NetAddress(irn.object_map[conn.post_obj.ensemble],
InputPort.neurons), {})
return ir.soss(NetAddress(irn.object_map[conn.post_obj.ensemble],
InputPort.neurons))
elif (conn.transform.ndim > 0 and
np.all(conn.transform == conn.transform[0])):
# This is a global inhibition connection and can be optimised
return (NetAddress(irn.object_map[conn.post_obj.ensemble],
InputPort.global_inhibition), {})
return ir.soss(NetAddress(irn.object_map[conn.post_obj.ensemble],
InputPort.global_inhibition), {})
raise NotImplementedError


Expand Down
63 changes: 35 additions & 28 deletions nengo_spinnaker/intermediate_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def build_deleted_node(node):
Each callable must accept as arguments the connection that the source is
for and a :py:class:`.IntermediateRepresentation` which it can use to get
information about other objects in the network. A callable must return a
2-tuple of (:py:class:`nengo_spinnaker.netlist.NetAddr`, **kwargs) where
accepted keys are currently "keyspace", "extra_objects",
"extra_connections" and "latching".
:py:class:`.SinkOrSourceSpecification` (see also :py:class:`.soss`).
For example, the standard source getter which just returns a source
indicating the source object and the standard output port is implemented
Expand All @@ -90,7 +88,7 @@ def build_deleted_node(node):
def get_source_standard(conn, ir_network):
source_obj = ir_network.object_map[conn.pre_obj]
source = NetAddr(source_obj, OutputPort.standard)
return source, {}
return soss(source)
"""

sink_getters = registerabledict()
Expand All @@ -107,7 +105,7 @@ def get_source_standard(conn, ir_network):
def get_sink_standard(conn, ir_network):
sink_obj = ir_network.object_map[conn.post_obj]
sink = NetAddr(sink_obj, InputPort.standard)
return sink, {}
return soss(sink)
"""

probe_builders = registerabledict()
Expand Down Expand Up @@ -287,6 +285,25 @@ def _filter_nets(self, f, key):
return nets


class SinkOrSourceSpecification(collections.namedtuple(
"SOSS",
"target extra_objects extra_nets keyspace latching weight"
)):
"""Specification for a source or sink as returned by a source or sink
getter.
"""
def __new__(cls, source_or_sink, extra_objects=list(),
extra_nets=list(), keyspace=None, latching=False, weight=None):
return super(SinkOrSourceSpecification, cls).__new__(
cls, source_or_sink, list(extra_objects), list(extra_nets),
keyspace, latching, weight
)


soss = SinkOrSourceSpecification
"""Quick reference to :py:class:`.SinkOrSourceSpecification`"""


class IntermediateNet(
collections.namedtuple("IntermediateNet",
"seed source sink keyspace latching weight")):
Expand Down Expand Up @@ -350,7 +367,7 @@ def get_source_standard(conn, irn):
conn : :py:class:`nengo.Connection`
irn : :py:class:`.IntermediateRepresentation`
"""
return (NetAddress(irn.object_map[conn.pre_obj], OutputPort.standard), {})
return soss(NetAddress(irn.object_map[conn.pre_obj], OutputPort.standard))


@IntermediateRepresentation.sink_getters.register(nengo.base.NengoObject)
Expand All @@ -362,7 +379,7 @@ def get_sink_standard(conn, irn):
conn : :py:class:`nengo.Connection`
irn : :py:class:`.IntermediateRepresentation`
"""
return (NetAddress(irn.object_map[conn.post_obj], InputPort.standard), {})
return soss(NetAddress(irn.object_map[conn.post_obj], InputPort.standard))


@IntermediateRepresentation.probe_builders.register(nengo.Node)
Expand Down Expand Up @@ -460,25 +477,25 @@ def _get_intermediate_endpoint(endpoint, getters, connection, irn):
if getattr(connection, "seed", None) is None else connection.seed)

# Get the source for the connection
source, source_extras = _get_intermediate_endpoint(
source_spec = _get_intermediate_endpoint(
_EndpointType.source, source_getters, connection, irn)

# If no source is specified then we abort the connection
if source is None:
if source_spec is None or source_spec.target is None:
return None, [], []

# Get the sink for the connection
sink, sink_extras = _get_intermediate_endpoint(
sink_spec = _get_intermediate_endpoint(
_EndpointType.sink, sink_getters, connection, irn)

# If no sink is specified then we abort the connection
if sink is None:
if sink_spec is None or sink_spec.target is None:
return None, [], []

# Resolve the keyspaces, allow either end to require a keyspace: if both
# ends require keyspaces then we fail.
source_ks = source_extras.pop("keyspace", None)
sink_ks = sink_extras.pop("keyspace", None)
source_ks = source_spec.keyspace
sink_ks = sink_spec.keyspace
if source_ks is None and sink_ks is None:
ks = None
elif source_ks is not None and sink_ks is None:
Expand All @@ -497,26 +514,16 @@ def _get_intermediate_endpoint(endpoint, getters, connection, irn):
# shouldn't be any case where there is a mismatch between the source and
# sink in this regard, or where a mismatch would result in incorrect
# behaviour.
latching = (source_extras.pop("latching", False) or
sink_extras.pop("latching", False))
latching = source_spec.latching or sink_spec.latching

# Combine the sets of extra objects and connections requested by the sink
# and sources.
extra_objs = (source_extras.pop("extra_objects", list()) +
sink_extras.pop("extra_objects", list()))
extra_conns = (source_extras.pop("extra_connections", list()) +
sink_extras.pop("extra_connections", list()))

# Complain if there were any keywords that we didn't understand.
for key in itertools.chain(*[six.iterkeys(s) for s in
[source_extras, sink_extras]]):
raise NotImplementedError(
"Unrecognised source/sink parameter {}".format(key)
)
extra_objs = source_spec.extra_objects + sink_spec.extra_objects
extra_conns = source_spec.extra_nets + sink_spec.extra_nets

# Build the new net
return (IntermediateNet(seed, source, sink, ks, latching, weight),
extra_objs, extra_conns)
return (IntermediateNet(seed, source_spec.target, sink_spec.target, ks,
latching, weight), extra_objs, extra_conns)


def _get_intermediate_probe(builders, probe, irn):
Expand Down
13 changes: 7 additions & 6 deletions nengo_spinnaker/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,19 @@ def get_probe_for_node(self, probe, seed, irn):
"""Get a probe object for the Node."""
# Get a source for the Node; then add a new probe object and a
# connection from the source to the probe.
source, extras = self.get_source_for_node(probe.target)
source_spec = self.get_source_for_node(probe.target)
probe_object = IntermediateObject(probe, seed)
probe_conn = IntermediateNet(
seed,
nl.NetAddress(source, nl.OutputPort.standard),
nl.NetAddress(source_spec.target, nl.OutputPort.standard),
nl.NetAddress(probe_object, nl.InputPort.standard),
keyspace=extras.pop("keyspace", None),
latching=extras.pop("latching", False)
keyspace=source_spec.keyspace,
latching=source_spec.latching,
weight=probe.size_in
)

objects = extras.pop("extra_objects", list()) + [source]
conns = extras.pop("extra_connections", list()) + [probe_conn]
objects = source_spec.extra_objects + [source_spec.target]
conns = source_spec.extra_nets + [probe_conn]

return probe_object, objects, conns

Expand Down
10 changes: 5 additions & 5 deletions nengo_spinnaker/tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_get_sink_standard(self):
irn = ir.IntermediateRepresentation(obj_map, {}, [], [])
assert (
ns_ens.get_ensemble_sink(c, irn) ==
(nl.NetAddress(obj_map[b], nl.InputPort.standard), {})
ir.soss(nl.NetAddress(obj_map[b], nl.InputPort.standard))
)

def test_get_sink_constant_node(self):
Expand All @@ -70,7 +70,7 @@ def test_get_sink_constant_node(self):

# We don't return a sink (None means "no connection required")
irn = ir.IntermediateRepresentation(obj_map, {}, [], [])
assert ns_ens.get_ensemble_sink(c, irn) == (None, {})
assert ns_ens.get_ensemble_sink(c, irn) is None

# But the Node values are added into the intermediate representation
# for the ensemble with the connection transform and function applied.
Expand All @@ -80,7 +80,7 @@ def test_get_sink_constant_node(self):

# For the next connection assert that we again don't add a connection
# and that the direct input is increased.
assert ns_ens.get_ensemble_sink(d, irn) == (None, {})
assert ns_ens.get_ensemble_sink(d, irn) is None
assert np.all(obj_map[b].direct_input ==
np.dot(full_transform(c, slice_pre=False),
c.function(a.output[c.pre_slice])) +
Expand All @@ -105,7 +105,7 @@ def test_neurons_to_neurons(self):
irn = ir.IntermediateRepresentation(obj_map, {}, [], [])
assert (
ns_ens.get_neurons_sink(c, irn) ==
(nl.NetAddress(obj_map[b], nl.InputPort.neurons), {})
ir.soss(nl.NetAddress(obj_map[b], nl.InputPort.neurons))
)

@pytest.mark.parametrize(
Expand All @@ -131,7 +131,7 @@ def test_global_inhibition(self, a):
irn = ir.IntermediateRepresentation(obj_map, {}, [], [])
assert (
ns_ens.get_neurons_sink(c, irn) ==
(nl.NetAddress(obj_map[b], nl.InputPort.global_inhibition), {})
ir.soss(nl.NetAddress(obj_map[b], nl.InputPort.global_inhibition))
)

def test_other(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import nengo
import numpy as np

from nengo_spinnaker import ensemble as ns_ens
from nengo_spinnaker import intermediate_representation as ir
from nengo_spinnaker import netlist as nl


def test_ensemble():
"""A functional test that looks at constant value Nodes, probes, global
inhibition connections.
"""
with nengo.Network() as net:
a = nengo.Node(output=0.5)
b = nengo.Ensemble(100, 1)
c = nengo.Ensemble(100, 5)
d = nengo.Ensemble(100, 4)

p_spikes = nengo.Probe(b.neurons)
p_value = nengo.Probe(d)

conn_ab = nengo.Connection(a, b) # Optimised out

# Global inhibition connections
conn_bc = nengo.Connection(b, c.neurons, transform=[[-1]]*c.n_neurons)

conn_cd = nengo.Connection(c[:4], d) # Normal

# Build the intermediate representation
irn = ir.IntermediateRepresentation.from_objs_conns_probes(
net.all_objects, net.connections, net.probes)

# Check that b, c, and d are intermediate ensembles
assert isinstance(irn.object_map[b], ns_ens.IntermediateEnsemble)
assert irn.object_map[b].local_probes == [p_spikes]
assert irn.object_map[b].direct_input == 0.5

assert isinstance(irn.object_map[c], ns_ens.IntermediateEnsemble)
assert irn.object_map[c].local_probes == list()
assert np.all(irn.object_map[c].direct_input == np.zeros(5))

assert isinstance(irn.object_map[d], ns_ens.IntermediateEnsemble)
assert irn.object_map[d].local_probes == list()
assert np.all(irn.object_map[d].direct_input == np.zeros(4))

# Check that conn a->b was optimised out
assert conn_ab not in irn.connection_map

# Check that conn b->c was identified as global inhibition
assert (irn.connection_map[conn_bc].sink.port is
nl.InputPort.global_inhibition)

# Check that conn c->d was left as normal
assert (irn.connection_map[conn_cd].sink.port is
nl.InputPort.standard)

# The probe on d should be in the object map
assert p_value in irn.object_map

# There should be a connection d->p_value
conn = irn.extra_connections[0]
assert conn.source.object is irn.object_map[d]
assert conn.source.port is nl.OutputPort.standard
assert conn.sink.object is irn.object_map[p_value]
assert conn.sink.port is nl.InputPort.standard

0 comments on commit 94a1849

Please sign in to comment.