// Source: GitHub file view of commit 2c87f3d (Nov 5, 2016);
// 377 lines (353 sloc), 12.1 KB.
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/**
 * A single neuron of the network.
 *
 * Besides its wiring (incoming and outgoing links) and its parameters (bias
 * and activation function), a node carries mutable state that is overwritten
 * by every forward pass (total input and output) and every backward pass
 * (the error derivatives and their running accumulators).
 */
export class Node {
  id: string;
  /** Links feeding into this node. */
  inputLinks: Link[] = [];
  bias = 0.1;
  /** Links leaving this node. */
  outputs: Link[] = [];
  totalInput: number;
  output: number;
  /** dE/d(output) for this node, from the latest backward pass. */
  outputDer = 0;
  /** dE/d(totalInput) for this node, from the latest backward pass. */
  inputDer = 0;
  /**
   * Running sum of dE/d(totalInput) since the last parameter update. Since
   * totalInput = bias + sum(weight * input), this is also the summed dE/d(bias).
   */
  accInputDer = 0;
  /** How many derivatives have been summed into accInputDer since the last update. */
  numAccumulatedDers = 0;
  /** Maps the node's total input to its output, and supplies its derivative. */
  activation: ActivationFunction;

  /**
   * @param id Identifier of the node.
   * @param activation Activation function applied to the total input.
   * @param initZero When true, the bias starts at 0 instead of the default 0.1.
   */
  constructor(id: string, activation: ActivationFunction, initZero?: boolean) {
    this.id = id;
    this.activation = activation;
    if (initZero) {
      this.bias = 0;
    }
  }

  /**
   * Sets totalInput to bias + sum(weight * source output) over the input
   * links (summed in link order), runs it through the activation function,
   * then stores and returns the result.
   */
  updateOutput(): number {
    this.totalInput = this.inputLinks.reduce(
        (runningSum, link) => runningSum + link.weight * link.source.output,
        this.bias);
    this.output = this.activation.output(this.totalInput);
    return this.output;
  }
}
/**
* An error function and its derivative.
*
* backProp seeds back-propagation with `der(output, target)`, i.e. the
* derivative of the error with respect to the network's output.
*/
export interface ErrorFunction {
/** Error of the produced `output` against the expected `target`. */
error: (output: number, target: number) => number;
/** Derivative of `error` with respect to `output`. */
der: (output: number, target: number) => number;
}
/** A node's activation function and its derivative. */
export interface ActivationFunction {
/** Maps a node's total input to its output (called by Node.updateOutput). */
output: (input: number) => number;
/**
* Derivative of `output` with respect to its input; backProp evaluates it at
* the node's total input.
*/
der: (input: number) => number;
}
/** Function that computes a penalty cost for a given weight in the network. */
// NOTE: a class with this same name (holding the built-in L1/L2 functions) is
// also declared in this file; TypeScript merges the two declarations, so the
// class's instance type carries the `output`/`der` members declared here.
export interface RegularizationFunction {
/** Penalty cost contributed by `weight`. */
output: (weight: number) => number;
/**
* Derivative of the penalty with respect to `weight`; updateWeights scales it
* by the regularization rate when adjusting a link's weight.
*/
der: (weight: number) => number;
}
/** Built-in error functions. */
export class Errors {
  /**
   * Half squared error, E = (output - target)^2 / 2. The factor 1/2 makes the
   * derivative dE/d(output) simply output - target.
   */
  public static SQUARE: ErrorFunction = {
    error: (output: number, target: number) => {
      const residual = output - target;
      return 0.5 * Math.pow(residual, 2);
    },
    der: (output: number, target: number) => {
      const residual = output - target;
      return residual;
    }
  };
}
/**
 * Fallback hyperbolic tangent for environments without a native Math.tanh.
 *
 * Uses tanh(x) = (e^(2x) - 1) / (e^(2x) + 1). For |x| > 20 the result is
 * returned directly as +/-1: tanh(20) already rounds to exactly 1 in double
 * precision. The guard also matters because for x beyond ~354,
 * Math.exp(2 * x) overflows to Infinity and the quotient would become
 * Infinity / Infinity = NaN. NaN inputs still propagate as NaN.
 */
function tanhPolyfill(x: number): number {
  if (x > 20) {
    return 1;
  }
  if (x < -20) {
    return -1;
  }
  const e2x = Math.exp(2 * x);
  return (e2x - 1) / (e2x + 1);
}
/** Polyfill for TANH: keep the native Math.tanh when it exists. */
(<any>Math).tanh = (<any>Math).tanh || tanhPolyfill;
/**
 * Built-in activation functions. Each entry pairs the function with its
 * derivative with respect to the node's total input.
 */
export class Activations {
  /** Hyperbolic tangent; d/dx tanh(x) = 1 - tanh(x)^2. */
  public static TANH: ActivationFunction = {
    output: (input) => (<any>Math).tanh(input),
    der: (input) => {
      const y = Activations.TANH.output(input);
      return 1 - y * y;
    }
  };
  /** Rectified linear unit max(0, x); its derivative is taken to be 0 at x <= 0. */
  public static RELU: ActivationFunction = {
    output: (input) => Math.max(0, input),
    der: (input) => {
      if (input <= 0) {
        return 0;
      }
      return 1;
    }
  };
  /** Logistic sigmoid s(x) = 1 / (1 + e^-x); its derivative is s(x) * (1 - s(x)). */
  public static SIGMOID: ActivationFunction = {
    output: (input) => 1 / (1 + Math.exp(-input)),
    der: (input) => {
      const s = Activations.SIGMOID.output(input);
      return s * (1 - s);
    }
  };
  /** Identity function; its derivative is the constant 1. */
  public static LINEAR: ActivationFunction = {
    output: (input) => input,
    der: (_input) => 1
  };
}
/**
 * Built-in regularization functions.
 *
 * The class shares its name with the RegularizationFunction interface;
 * TypeScript merges the two declarations, so each static member below is an
 * object with the interface's `output`/`der` members.
 */
export class RegularizationFunction {
  /**
   * L1 penalty |w|. Its derivative is sign(w). At w = 0, where |w| is not
   * differentiable, the subgradient 0 is used, so a weight that is exactly
   * zero (e.g. after zero initialization) is not pushed away from zero.
   */
  public static L1: RegularizationFunction = {
    output: w => Math.abs(w),
    der: w => w < 0 ? -1 : (w > 0 ? 1 : 0)
  };
  /** L2 penalty w^2 / 2, whose derivative is w. */
  public static L2: RegularizationFunction = {
    output: w => 0.5 * w * w,
    der: w => w
  };
}
/**
 * A weighted, directed connection from a source node to a destination node.
 * Alongside the weight it keeps the back-propagation state for that weight:
 * the derivative from the latest backward pass and a running sum of
 * derivatives collected since the last weight update.
 */
export class Link {
  id: string;
  source: Node;
  dest: Node;
  weight = Math.random() - 0.5;
  /** dE/d(weight) from the most recent backward pass. */
  errorDer = 0;
  /** Sum of dE/d(weight) values since the last update. */
  accErrorDer = 0;
  /** Count of derivatives summed into accErrorDer since the last update. */
  numAccumulatedDers = 0;
  regularization: RegularizationFunction;

  /**
   * Creates the link. Its id is "<source id>-<dest id>".
   *
   * @param source Node the link starts at.
   * @param dest Node the link ends at.
   * @param regularization Penalty function for this weight; if null, the
   *     weight is not regularized.
   * @param initZero When true, the weight is set to 0 instead of a random
   *     value in [-0.5, 0.5). The random draw still happens either way.
   */
  constructor(source: Node, dest: Node,
      regularization: RegularizationFunction, initZero?: boolean) {
    this.id = `${source.id}-${dest.id}`;
    this.source = source;
    this.dest = dest;
    this.regularization = regularization;
    if (initZero) {
      this.weight = 0;
    }
  }
}
/**
 * Builds a fully connected feed-forward network.
 *
 * Every node in layer k > 0 receives one link from each node of layer k - 1.
 * Input nodes take their ids from `inputIds`. All other nodes get consecutive
 * numeric ids "1", "2", ..., counted across layers in creation order.
 *
 * @param networkShape The shape of the network. E.g. [1, 2, 3, 1] means
 *     the network will have one input node, 2 nodes in the first hidden
 *     layer, 3 nodes in the second hidden layer and 1 output node.
 * @param activation The activation function of every non-output node.
 * @param outputActivation The activation function for the output-layer nodes.
 * @param regularization The regularization function that computes a penalty
 *     for a given weight (parameter) in the network. If null, there will be
 *     no regularization.
 * @param inputIds List of ids for the input nodes.
 * @param initZero When true, every bias and weight starts at 0.
 * @return The layers of the network, input layer first.
 */
export function buildNetwork(
    networkShape: number[], activation: ActivationFunction,
    outputActivation: ActivationFunction,
    regularization: RegularizationFunction,
    inputIds: string[], initZero?: boolean): Node[][] {
  const lastLayerIdx = networkShape.length - 1;
  const network: Node[][] = [];
  let nextId = 1;
  for (let layerIdx = 0; layerIdx <= lastLayerIdx; layerIdx++) {
    const layer: Node[] = [];
    network.push(layer);
    const layerActivation =
        layerIdx === lastLayerIdx ? outputActivation : activation;
    const prevLayer = layerIdx > 0 ? network[layerIdx - 1] : null;
    const layerSize = networkShape[layerIdx];
    for (let i = 0; i < layerSize; i++) {
      let nodeId: string;
      if (layerIdx === 0) {
        nodeId = inputIds[i];
      } else {
        nodeId = String(nextId);
        nextId++;
      }
      const node = new Node(nodeId, layerActivation, initZero);
      layer.push(node);
      if (prevLayer !== null) {
        // Fully connect the previous layer to this node.
        for (const prevNode of prevLayer) {
          const link = new Link(prevNode, node, regularization, initZero);
          prevNode.outputs.push(link);
          node.inputLinks.push(link);
        }
      }
    }
  }
  return network;
}
/**
 * Pushes `inputs` through the network, overwriting each node's output
 * (and, for non-input nodes, total input) along the way.
 *
 * @param network The neural network, input layer first.
 * @param inputs One value per input node, in input-layer order.
 * @return The output of the first node of the last layer.
 * @throws Error if `inputs` and the input layer differ in length.
 */
export function forwardProp(network: Node[][], inputs: number[]): number {
  const inputLayer = network[0];
  if (inputs.length !== inputLayer.length) {
    throw new Error("The number of inputs must match the number of nodes in" +
        " the input layer");
  }
  // Input nodes simply take on the provided values.
  let inputIdx = 0;
  for (const inputNode of inputLayer) {
    inputNode.output = inputs[inputIdx++];
  }
  // Later layers are evaluated in order, so every node reads fresh outputs
  // from the layer before it.
  for (const layer of network.slice(1)) {
    for (const node of layer) {
      node.updateOutput();
    }
  }
  const outputLayer = network[network.length - 1];
  return outputLayer[0].output;
}
/**
 * Runs back-propagation for `target` using the state left by the previous
 * forwardProp call.
 *
 * For every non-input node this sets outputDer and inputDer. For every link
 * it sets errorDer. It also adds these derivatives into the running
 * accumulators (accInputDer / accErrorDer) and bumps their counters, which
 * updateWeights later consumes.
 *
 * @param network The neural network, input layer first.
 * @param target The expected value for the output node.
 * @param errorFunc Supplies dE/d(output) for the output node.
 */
export function backProp(network: Node[][], target: number,
    errorFunc: ErrorFunction): void {
  // The output node is the starting point: its dE/d(output) comes straight
  // from the user-supplied error function.
  const outputNode = network[network.length - 1][0];
  outputNode.outputDer = errorFunc.der(outputNode.output, target);

  // Walk from the output layer back towards (but not into) the input layer.
  for (let layerIdx = network.length - 1; layerIdx > 0; layerIdx--) {
    const layer = network[layerIdx];

    // dE/d(totalInput) = dE/d(output) * activation'(totalInput).
    for (const node of layer) {
      node.inputDer = node.outputDer * node.activation.der(node.totalInput);
      node.accInputDer += node.inputDer;
      node.numAccumulatedDers++;
    }

    // dE/d(weight) = dE/d(totalInput of dest) * output of source.
    for (const node of layer) {
      for (const link of node.inputLinks) {
        link.errorDer = node.inputDer * link.source.output;
        link.accErrorDer += link.errorDer;
        link.numAccumulatedDers++;
      }
    }

    // The previous layer's dE/d(output) is only needed if that layer is not
    // the input layer.
    if (layerIdx > 1) {
      for (const node of network[layerIdx - 1]) {
        let der = 0;
        for (const link of node.outputs) {
          der += link.weight * link.dest.inputDer;
        }
        node.outputDer = der;
      }
    }
  }
}
/**
 * Applies one gradient-descent step to every bias and incoming weight of the
 * non-input layers, using the derivatives accumulated by backProp since the
 * previous call. It then clears those accumulators.
 *
 * Each parameter moves by -learningRate times its gradient:
 *  - bias: the mean accumulated dE/d(totalInput);
 *  - weight: the mean accumulated dE/dw plus regularizationRate times the
 *    derivative of the link's regularization penalty at the current weight.
 * The penalty term is added once per step and is NOT divided by the number of
 * accumulated examples. That way the regularization strength does not shrink
 * as the batch grows, matching cost = mean(error) + rate * penalty(w).
 * Parameters with no accumulated derivatives are left untouched.
 *
 * @param network The neural network, input layer first (layer 0 is skipped).
 * @param learningRate Step size of the update.
 * @param regularizationRate Multiplier for the regularization derivative;
 *     ignored for links whose regularization is null.
 */
export function updateWeights(network: Node[][], learningRate: number,
    regularizationRate: number) {
  for (let layerIdx = 1; layerIdx < network.length; layerIdx++) {
    for (const node of network[layerIdx]) {
      // Update the node's bias.
      if (node.numAccumulatedDers > 0) {
        node.bias -= learningRate * node.accInputDer / node.numAccumulatedDers;
        node.accInputDer = 0;
        node.numAccumulatedDers = 0;
      }
      // Update the weights coming into this node.
      for (const link of node.inputLinks) {
        if (link.numAccumulatedDers > 0) {
          const meanErrorDer = link.accErrorDer / link.numAccumulatedDers;
          const regulDer = link.regularization ?
              link.regularization.der(link.weight) : 0;
          link.weight -=
              learningRate * (meanErrorDer + regularizationRate * regulDer);
          link.accErrorDer = 0;
          link.numAccumulatedDers = 0;
        }
      }
    }
  }
}
/**
 * Calls `accessor` on every node in the network, layer by layer from the
 * input layer to the output layer.
 *
 * @param network The neural network, input layer first.
 * @param ignoreInputs When true, the input layer (index 0) is skipped.
 * @param accessor Callback invoked once per visited node.
 */
export function forEachNode(network: Node[][], ignoreInputs: boolean,
    accessor: (node: Node) => any) {
  let layerIdx = ignoreInputs ? 1 : 0;
  while (layerIdx < network.length) {
    for (const node of network[layerIdx]) {
      accessor(node);
    }
    layerIdx++;
  }
}
/** Returns the network's output node: the first node of the last layer. */
export function getOutputNode(network: Node[][]) {
  const lastLayer = network[network.length - 1];
  return lastLayer[0];
}