-
Notifications
You must be signed in to change notification settings - Fork 7
/
connection.py
119 lines (89 loc) · 3.64 KB
/
connection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import nengo
import numpy as np
from .builder import Model, ObjectPort, spec
from .model import ReceptionParameters, InputPort, OutputPort
@Model.source_getters.register(nengo.base.NengoObject)
def generic_source_getter(model, conn):
    """Return the source specification for a connection from a generic object.

    The operator stored for ``conn.pre_obj`` in ``model.object_operators`` is
    paired with the standard output port, and the resulting ``ObjectPort`` is
    wrapped in a ``spec``.
    """
    operator = model.object_operators[conn.pre_obj]
    port = ObjectPort(operator, OutputPort.standard)
    return spec(port)
@Model.sink_getters.register(nengo.base.NengoObject)
def generic_sink_getter(model, conn):
    """Return the sink specification for a connection to a generic object.

    The operator stored for ``conn.post_obj`` in ``model.object_operators`` is
    paired with the standard input port, and the resulting ``ObjectPort`` is
    wrapped in a ``spec``.
    """
    operator = model.object_operators[conn.post_obj]
    port = ObjectPort(operator, InputPort.standard)
    return spec(port)
@Model.reception_parameter_builders.register(nengo.base.NengoObject)
@Model.reception_parameter_builders.register(nengo.ensemble.Neurons)
def build_generic_reception_params(model, conn):
    """Construct the reception parameters for packets simulating ``conn``.

    The result pairs the connection's synapse with the input width of the
    post-synaptic object (``conn.post_obj.size_in``). ``model`` is not used.
    """
    synapse = conn.synapse
    width = conn.post_obj.size_in
    return ReceptionParameters(synapse, width)
class EnsembleTransmissionParameters(object):
    """Transmission parameters for a connection originating at an Ensemble.

    Parameters
    ----------
    decoders : array_like
        Decoders for the connection before the transform is applied. Any
        array-like (e.g. nested lists) is accepted; it is copied.
    transform : array_like
        Transform to apply to the decoded value; it is copied.

    Attributes
    ----------
    untransformed_decoders : array
        Read-only copy of ``decoders`` as given.
    transform : array
        Read-only copy of ``transform``.
    decoders : array
        Read-only decoders to use for the connection, computed as
        ``dot(transform, untransformed_decoders.T).T``.
    """

    def __init__(self, decoders, transform):
        # Copy the decoders and the transform into fresh arrays
        self.untransformed_decoders = np.array(decoders)
        self.transform = np.array(transform)

        # Compute and store the transformed decoders. Work from the copied
        # arrays rather than the raw arguments so that array-likes without
        # a ``.T`` attribute (e.g. lists) are accepted.
        self.decoders = np.dot(self.transform,
                               self.untransformed_decoders.T).T

        # Make the arrays read-only
        self.untransformed_decoders.flags['WRITEABLE'] = False
        self.transform.flags['WRITEABLE'] = False
        self.decoders.flags['WRITEABLE'] = False

    def __ne__(self, other):
        return not (self == other)

    def __eq__(self, other):
        """Equal iff. ``other`` is exactly the same type and its transformed
        decoders have the same shape and values.

        The untransformed decoders and the transform are not compared.
        """
        if type(self) is not type(other):
            return False
        # np.array_equal checks the shapes before comparing the values
        return np.array_equal(self.decoders, other.decoders)

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable on Python 3 (and
        # gives an id-based hash on Python 2 that disagrees with __eq__).
        # Hash only on values __eq__ requires to match, so equal objects
        # always hash equally.
        return hash((type(self), self.decoders.shape))
class PassthroughNodeTransmissionParameters(object):
    """Parameters describing connections which originate from pass through
    Nodes.

    Parameters
    ----------
    transform : array_like
        Transform applied by the connection; it is copied into a new array.

    Attributes
    ----------
    transform : array
        Copy of ``transform`` (left writeable).
    """

    def __init__(self, transform):
        # Store the parameters, copying the transform
        self.transform = np.array(transform)

    def __ne__(self, other):
        return not (self == other)

    def __eq__(self, other):
        """Equal iff. ``other`` is exactly the same type and the transforms
        have the same shape and values."""
        if type(self) is not type(other):
            return False
        # np.array_equal checks the shapes before comparing the values
        return np.array_equal(self.transform, other.transform)

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable on Python 3.
        # Hash only the type and the transform's shape: both must match for
        # equality, and in-place edits to the (writeable) transform values
        # cannot make equal objects hash differently.
        return hash((type(self), self.transform.shape))
class NodeTransmissionParameters(PassthroughNodeTransmissionParameters):
    """Parameters describing connections which originate from Nodes.

    Parameters
    ----------
    pre_slice :
        Slice of the Node's output used by the connection. Compared with
        ``!=``, or element-wise when either side is a NumPy array.
    function :
        Function applied by the connection (presumably a callable or None;
        confirm against callers). Compared by identity only.
    transform : array_like
        Transform applied by the connection; it is copied.
    """

    def __init__(self, pre_slice, function, transform):
        # Store the parameters
        super(NodeTransmissionParameters, self).__init__(transform)
        self.pre_slice = pre_slice
        self.function = function

    def __hash__(self):
        # Hash only on values __eq__ requires to match (type, transform
        # shape, function identity) so equal objects hash equally. Hashing
        # by id(self) would put equal objects in different dict buckets.
        # pre_slice is left out because it may be unhashable (e.g. a slice
        # or a list).
        return hash((type(self), self.transform.shape, id(self.function)))

    def __eq__(self, other):
        """Equal iff. the parent comparison holds, the pre_slices match and
        the very same function object is used."""
        # Parent equivalence (also guarantees ``other`` has the same type)
        if not super(NodeTransmissionParameters, self).__eq__(other):
            return False

        # Equivalent if the pre_slices are exactly the same. A plain ``!=``
        # involving a NumPy array yields an array whose truth value is
        # ambiguous, so compare arrays element-wise instead.
        pre, other_pre = self.pre_slice, other.pre_slice
        if isinstance(pre, np.ndarray) or isinstance(other_pre, np.ndarray):
            if not np.array_equal(pre, other_pre):
                return False
        elif pre != other_pre:
            return False

        # Equivalent if the functions are the same object
        if self.function is not other.function:
            return False

        return True