Skip to content

Commit

Permalink
Improve matrix storage
Browse files Browse the repository at this point in the history
Rewrite the transmission parameters to reduce code complexity and save
memory.
  • Loading branch information
mundya committed Nov 22, 2016
1 parent e54768a commit 0226dfd
Show file tree
Hide file tree
Showing 20 changed files with 1,107 additions and 631 deletions.
105 changes: 0 additions & 105 deletions nengo_spinnaker/builder/connection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
import nengo
import numpy as np

from .builder import Model, ObjectPort, spec
from .model import ReceptionParameters, InputPort, OutputPort

try:
from xxhash import xxh64 as fast_hash
except ImportError:
from hashlib import md5 as fast_hash


@Model.source_getters.register(nengo.base.NengoObject)
def generic_source_getter(model, conn):
Expand All @@ -32,101 +25,3 @@ def build_generic_reception_params(model, conn):
# Just extract the synapse from the connection.
return ReceptionParameters(conn.synapse, conn.post_obj.size_in,
conn.learning_rule)


class TransmissionParameters(object):
"""Parameters describing generic connections."""
def __init__(self, transform):
self.transform = np.array(transform, order='C')
self.transform.flags['WRITEABLE'] = False

def __ne__(self, other):
return not (self == other)

def __eq__(self, other):
# Equivalent if the same type
if type(self) is not type(other):
return False

# and the transforms are equivalent
if not np.array_equal(self.transform, other.transform):
return False

return True

@property
def _hashables(self):
return (type(self), fast_hash(self.transform).hexdigest())

def __hash__(self):
return hash(self._hashables)


class EnsembleTransmissionParameters(TransmissionParameters):
"""Transmission parameters for a connection originating at an Ensemble.
Attributes
----------
transform : array
Decoders to use for the connection (n_dims x n_neurons)
"""
def __init__(self, transform, learning_rule):
super(EnsembleTransmissionParameters, self).__init__(transform)

# Cache learning rule
self.learning_rule = learning_rule

def __ne__(self, other):
return not (self == other)

def __eq__(self, other):
# Equal iff. neither connection has a learning rule
if self.learning_rule is not None or other.learning_rule is not None:
return False

return super(EnsembleTransmissionParameters, self).__eq__(other)

def __hash__(self):
return hash(
super(EnsembleTransmissionParameters, self)._hashables +
(self.learning_rule, )
)


class PassthroughNodeTransmissionParameters(TransmissionParameters):
"""Parameters describing connections which originate from pass through
Nodes.
"""


class NodeTransmissionParameters(PassthroughNodeTransmissionParameters):
"""Parameters describing connections which originate from Nodes."""
def __init__(self, pre_slice, function, transform):
# Store the parameters
super(NodeTransmissionParameters, self).__init__(transform)
self.pre_slice = pre_slice
self.function = function

def __hash__(self):
# Hash by ID
return hash(id(self))

def __eq__(self, other):
# Parent equivalence
if not super(NodeTransmissionParameters, self).__eq__(other):
return False

# Equivalent if the pre_slices are exactly the same
if self.pre_slice != other.pre_slice:
return False

# Equivalent if the functions are the same
if self.function is not other.function:
return False

return True

@property
def _hashables(self):
return (super(NodeTransmissionParameters, self)._hashables +
(self.function, ))
32 changes: 8 additions & 24 deletions nengo_spinnaker/builder/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np

from .builder import BuiltConnection, Model, ObjectPort, spec
from .connection import EnsembleTransmissionParameters
from .transmission_parameters import EnsembleTransmissionParameters
from .model import InputPort, OutputPort
from .ports import EnsembleInputPort, EnsembleOutputPort
from .. import operators
Expand Down Expand Up @@ -121,20 +121,7 @@ def get_learning_rule_sink(model, connection):
def get_neurons_sink(model, connection):
"""Get the sink for connections into the neurons of an ensemble."""
ens = model.object_operators[connection.post_obj.ensemble]

if (connection.transform.ndim == 2 and
np.all(connection.transform[1:] == connection.transform[0])):
# Connections from non-neurons to Neurons where the transform delivers
# the same value to all neurons are treated as global inhibition
# connection.
# Return a signal to the correct port.
return spec(ObjectPort(ens, EnsembleInputPort.global_inhibition))
else:
# - Connections from Neurons can go straight to the Neurons.
# - Otherwise we don't support arbitrary connections into neurons, but
# we allow them because they may be optimised out later when we come
# to remove passthrough nodes.
return spec(ObjectPort(ens, EnsembleInputPort.neurons))
return spec(ObjectPort(ens, EnsembleInputPort.neurons))


ensemble_builders = collections_ext.registerabledict()
Expand Down Expand Up @@ -243,7 +230,7 @@ def build_from_ensemble_connection(model, conn):
rng = np.random.RandomState(model.seeds[conn])

# Get the transform
transform = full_transform(conn, slice_pre=False, allow_scalars=False)
transform = conn.transform

# Solve for the decoders
eval_points, decoders, solver_info = build_decoders(model, conn, rng)
Expand All @@ -254,14 +241,11 @@ def build_from_ensemble_connection(model, conn):
transform=transform,
solver_info=solver_info)

# Modify the transform if this is a global inhibition connection
if (isinstance(conn.post_obj, nengo.ensemble.Neurons) and
np.all(transform[0, :] == transform[1:, :])):
transform = np.array([transform[0]])

transform = np.dot(transform, decoders.T)

return EnsembleTransmissionParameters(transform, conn.learning_rule)
return EnsembleTransmissionParameters(decoders.T,
conn.post_obj.size_in,
conn.post_slice,
conn.learning_rule,
transform)


@Model.transmission_parameter_builders.register(nengo.ensemble.Neurons)
Expand Down
9 changes: 9 additions & 0 deletions nengo_spinnaker/builder/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""
import collections
import enum
from .ports import EnsembleInputPort
from six import iteritems, itervalues, iterkeys


Expand Down Expand Up @@ -67,6 +68,14 @@ def add_connection(self, source_object, source_port, signal_parameters,
Sink-specific parameters of how the received packets are to be
treated.
"""
# Swap out the connection for a global inhibition connection if
# possible.
if (sink_port is EnsembleInputPort.neurons and
transmission_parameters.supports_global_inhibition):
sink_port = EnsembleInputPort.global_inhibition
transmission_parameters = \
transmission_parameters.as_global_inhibition_connection

# Combine the signal parameters with the transmission parameters
# (These represent the signal and can be hashed)
pars = (signal_parameters, transmission_parameters)
Expand Down
46 changes: 23 additions & 23 deletions nengo_spinnaker/builder/node.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import nengo
from nengo.processes import Process
from nengo.utils.builder import full_transform
import numpy as np
import threading

from .connection import (PassthroughNodeTransmissionParameters,
NodeTransmissionParameters)
from .transmission_parameters import (
PassthroughNodeTransmissionParameters,
NodeTransmissionParameters)
from nengo_spinnaker.builder.builder import ObjectPort, spec, Model
from nengo_spinnaker.builder.model import InputPort, OutputPort
from nengo_spinnaker.operators import Filter, ValueSink, ValueSource
Expand Down Expand Up @@ -283,27 +283,27 @@ def close(self):
def build_node_transmission_parameters(model, conn):
"""Build transmission parameters for a connection originating at a Node."""
if conn.pre_obj.output is not None:
# Connection is not from a passthrough Node
# Get the full transform, not including the pre_slice
transform = full_transform(conn, slice_pre=False, allow_scalars=False)
else:
# Connection is from a passthrough Node
# Get the full transform
transform = full_transform(conn, allow_scalars=False)

# If the connection is to neurons and the transform is equivalent in every
# row we treat it as a global inhibition connection and shrink it down to
# one row.
if (isinstance(conn.post_obj, nengo.ensemble.Neurons) and
np.all(transform[0, :] == transform[1:, :])):
# Reduce the size of the transform
transform = np.array([transform[0]])

if conn.pre_obj.output is not None:
return NodeTransmissionParameters(conn.pre_slice, conn.function,
transform)
if conn.function is None:
size_in = conn.pre_obj.size_out
else:
size_in = conn.size_mid

return NodeTransmissionParameters(
size_in=size_in,
size_out=conn.post_obj.size_in,
transform=conn.transform,
slice_out=conn.post_slice,
pre_slice=conn.pre_slice,
function=conn.function
)
else:
return PassthroughNodeTransmissionParameters(transform)
return PassthroughNodeTransmissionParameters(
size_in=conn.pre_obj.size_out,
size_out=conn.post_obj.size_in,
transform=conn.transform,
slice_in=conn.pre_slice,
slice_out=conn.post_slice
)


class InputNode(nengo.Node):
Expand Down

0 comments on commit 0226dfd

Please sign in to comment.