Skip to content

Commit

Permalink
Refactor to avoid circular import
Browse files Browse the repository at this point in the history
Migrates `*TransmissionParameters` into
`nengo_spinnaker/builder/connection.py`
  • Loading branch information
mundya committed Oct 9, 2015
1 parent 4d36ba3 commit 8cadf19
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 83 deletions.
84 changes: 84 additions & 0 deletions nengo_spinnaker/builder/connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import nengo
import numpy as np
from .builder import Model, ObjectPort, spec
from .model import ReceptionParameters, InputPort, OutputPort

Expand All @@ -23,3 +24,86 @@ def build_generic_reception_params(model, conn):
"""
# Just extract the synapse from the connection.
return ReceptionParameters(conn.synapse, conn.post_obj.size_in)


class EnsembleTransmissionParameters(object):
"""Transmission parameters for a connection originating at an Ensemble.
Attributes
----------
decoders : array
Decoders to use for the connection (including the transform).
"""
def __init__(self, decoders):
# Copy the decoders
self.decoders = np.array(decoders)

def __ne__(self, other):
return not (self == other)

def __eq__(self, other):
# Equal iff. the objects are of the same type
if type(self) is not type(other):
return False

# Equal iff. the decoders are the same shape
if self.decoders.shape != other.decoders.shape:
return False

# Equal iff. the decoder values are the same
if np.any(self.decoders != other.decoders):
return False

return True


class PassthroughNodeTransmissionParameters(object):
"""Parameters describing connections which originate from pass through
Nodes.
"""
def __init__(self, transform):
# Store the parameters, copying the transform
self.transform = np.array(transform)

def __ne__(self, other):
return not (self == other)

def __eq__(self, other):
# Equivalent if the same type
if type(self) is not type(other):
return False

# and the transforms are equivalent
if (self.transform.shape != other.transform.shape or
np.any(self.transform != other.transform)):
return False

return True


class NodeTransmissionParameters(PassthroughNodeTransmissionParameters):
"""Parameters describing connections which originate from Nodes."""
def __init__(self, pre_slice, function, transform):
# Store the parameters
super(NodeTransmissionParameters, self).__init__(transform)
self.pre_slice = pre_slice
self.function = function

def __hash__(self):
# Hash by ID
return hash(id(self))

def __eq__(self, other):
# Parent equivalence
if not super(NodeTransmissionParameters, self).__eq__(other):
return False

# Equivalent if the pre_slices are exactly the same
if self.pre_slice != other.pre_slice:
return False

# Equivalent if the functions are the same
if self.function is not other.function:
return False

return True
32 changes: 1 addition & 31 deletions nengo_spinnaker/builder/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np

from .builder import BuiltConnection, Model, ObjectPort, spec
from .connection import EnsembleTransmissionParameters
from .model import InputPort
from .ports import EnsembleInputPort
from .. import operators
Expand Down Expand Up @@ -137,37 +138,6 @@ def build_lif(model, ens):
model.object_operators[ens] = operators.EnsembleLIF(ens)


class EnsembleTransmissionParameters(object):
"""Transmission parameters for a connection originating at an Ensemble.
Attributes
----------
decoders : array
Decoders to use for the connection (including the transform).
"""
def __init__(self, decoders):
# Copy the decoders
self.decoders = np.array(decoders)

def __ne__(self, other):
return not (self == other)

def __eq__(self, other):
# Equal iff. the objects are of the same type
if type(self) is not type(other):
return False

# Equal iff. the decoders are the same shape
if self.decoders.shape != other.decoders.shape:
return False

# Equal iff. the decoder values are the same
if np.any(self.decoders != other.decoders):
return False

return True


@Model.transmission_parameter_builders.register(nengo.Ensemble)
def build_from_ensemble_connection(model, conn):
"""Build the parameters object for a connection from an Ensemble."""
Expand Down
55 changes: 3 additions & 52 deletions nengo_spinnaker/builder/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

from nengo_spinnaker.builder.builder import ObjectPort, spec, Model
from nengo_spinnaker.builder.model import InputPort, OutputPort
from nengo_spinnaker.builder.connection import (
PassthroughNodeTransmissionParameters,
NodeTransmissionParameters)
from nengo_spinnaker.operators import Filter, ValueSink, ValueSource
from nengo_spinnaker.utils.config import getconfig

Expand Down Expand Up @@ -293,58 +296,6 @@ def build_node_transmission_parameters(model, conn):
return PassthroughNodeTransmissionParameters(transform)


class PassthroughNodeTransmissionParameters(object):
"""Parameters describing connections which originate from pass through
Nodes.
"""
def __init__(self, transform):
# Store the parameters, copying the transform
self.transform = np.array(transform)

def __ne__(self, other):
return not (self == other)

def __eq__(self, other):
# Equivalent if the same type
if type(self) is not type(other):
return False

# and the transforms are equivalent
if (self.transform.shape != other.transform.shape or
np.any(self.transform != other.transform)):
return False

return True


class NodeTransmissionParameters(PassthroughNodeTransmissionParameters):
"""Parameters describing connections which originate from Nodes."""
def __init__(self, pre_slice, function, transform):
# Store the parameters
super(NodeTransmissionParameters, self).__init__(transform)
self.pre_slice = pre_slice
self.function = function

def __hash__(self):
# Hash by ID
return hash(id(self))

def __eq__(self, other):
# Parent equivalence
if not super(NodeTransmissionParameters, self).__eq__(other):
return False

# Equivalent if the pre_slices are exactly the same
if self.pre_slice != other.pre_slice:
return False

# Equivalent if the functions are the same
if self.function is not other.function:
return False

return True


class InputNode(nengo.Node):
"""Node which queries the IO controller for the input to a Node from."""
def __init__(self, node, controller):
Expand Down

0 comments on commit 8cadf19

Please sign in to comment.