forked from liquidcarrot/carrot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MemoryLayer.ts
70 lines (62 loc) · 1.88 KB
/
MemoryLayer.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import { ActivationType, Identitiy, Logistic } from "activations";
import { Node } from "../../Node";
import { Layer } from "../Layer";
import { NodeType } from "../../../enums/NodeType";
import { ConnectionType } from "../../../enums/ConnectionType";
/**
* Memory layer
*/
/**
 * Memory layer.
 *
 * Consists of a group of `outputSize` input nodes followed by a chain of
 * `memorySize` blocks. Each block holds `outputSize` hidden nodes with an
 * identity squash and zero bias, and each block is wired one-to-one from the
 * group before it. The last group in the chain becomes the layer's output.
 *
 * NOTE(review): presumably each block carries the previous block's value from
 * an earlier activation step, so the chain acts as a delay line. TODO confirm
 * this against how the owning network orders node activation.
 */
export class MemoryLayer extends Layer {
  /**
   * @param outputSize number of nodes in the input group and in every memory block
   * @param options.activation squash assigned to the output nodes (`Logistic` when omitted)
   * @param options.memorySize number of memory blocks in the chain (1 when omitted);
   *   when the chain is empty, the input nodes themselves become the outputs
   */
  constructor(
    outputSize: number,
    options: {
      /**
       * The activation type for the output nodes of this layer.
       */
      activation?: ActivationType;
      /**
       * The size of the memory.
       */
      memorySize?: number;
    } = {}
  ) {
    super(outputSize);

    // Input group: fresh hidden nodes that receive the layer's incoming connections.
    let created = 0;
    while (created < outputSize) {
      this.inputNodes.add(new Node(NodeType.HIDDEN));
      created++;
    }

    // Builds one memory block: pass-through (identity, zero-bias) hidden nodes.
    const makeBlock = (): Node[] => {
      const cells: Node[] = [];
      for (let k = 0; k < outputSize; k++) {
        const cell: Node = new Node(NodeType.HIDDEN);
        cell.squash = Identitiy;
        cell.bias = 0;
        cells.push(cell);
      }
      return cells;
    };

    const depth: number = options.memorySize ?? 1;
    let upstream: Node[] = Array.from(this.inputNodes);
    const memoryNodes: Node[] = [];
    for (let step = 0; step < depth; step++) {
      const block: Node[] = makeBlock();
      // Each node in the new block is fed by its positional counterpart upstream.
      this.connections.push(...Layer.connect(upstream, block, ConnectionType.ONE_TO_ONE));
      memoryNodes.push(...block);
      upstream = block;
    }

    // Node order: inputs first, then every memory node in reverse creation order.
    // The reversal spans the flattened list, so the last block comes first and
    // each block's own nodes are reversed as well.
    this.nodes.push(...Array.from(this.inputNodes), ...memoryNodes.reverse());

    // The tail of the chain is the output group; its squash overrides the identity.
    for (const tail of upstream) {
      this.outputNodes.add(tail);
    }
    const outputSquash = options.activation ?? Logistic;
    this.outputNodes.forEach((node) => {
      node.squash = outputSquash;
    });
  }

  /**
   * Reports whether a connection type may be used with this layer.
   * This layer accepts every connection type.
   *
   * @returns always `true`
   */
  public connectionTypeisAllowed(): boolean {
    return true;
  }

  /**
   * Gets the default connection type for an incoming connection to this layer.
   *
   * @returns `ConnectionType.ALL_TO_ALL`
   */
  public getDefaultIncomingConnectionType(): ConnectionType {
    return ConnectionType.ALL_TO_ALL;
  }
}